"""Vectorization base classes and basic functionality.
"""
__author__ = 'Paul Landes'
from typing import Tuple, Any, Union, ClassVar
from dataclasses import dataclass, field
from abc import abstractmethod, ABCMeta
import logging
import sys
import warnings
from io import TextIOBase
from scipy import sparse
from scipy.sparse import csr_matrix
import torch
from torch import Tensor
from zensols.persist import PersistableContainer
from zensols.config import ConfigFactory, Writable
from zensols.deeplearn import DeepLearnError, TorchConfig
logger = logging.getLogger(__name__)
class VectorizerError(DeepLearnError):
    """Raised by :class:`.FeatureVectorizer` instances when an encoding or
    decoding operation fails.

    """
@dataclass
class ConfigurableVectorization(PersistableContainer, Writable):
    """A base class for vectorization components instantiated from the
    application configuration.

    """
    name: str = field()
    """The name of the section given in the configuration."""

    config_factory: ConfigFactory = field(repr=False)
    """The configuration factory that created this instance and used for
    serialization functions."""

    def __post_init__(self, *args, **kwargs):
        # forward to the (non-dataclass) base class initializers
        super().__init__(*args, **kwargs)

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write this instance's configuration section name to ``writer``."""
        self._write_line(f'name: {self.name}', depth, writer)
@dataclass
class FeatureVectorizer(ConfigurableVectorization, metaclass=ABCMeta):
    """An abstract base class that transforms a Python object in to a PyTorch
    tensor.

    """
    feature_id: str = field()
    """Uniquely identifies this vectorizer."""

    def _allow_config_adds(self) -> bool:
        return True

    @abstractmethod
    def _get_shape(self) -> Tuple[int, int]:
        """Return the shape of the tensor created by ``transform``."""
        pass

    @property
    def shape(self) -> Tuple[int, int]:
        """Return the shape of the tensor created by ``transform``.

        """
        return self._get_shape()

    @property
    def description(self) -> str:
        """A short human readable name.

        :see: obj:`feature_id`

        """
        # NOTE(review): ``DESCRIPTION`` appears to be a class attribute set by
        # concrete subclasses; it is not defined in this class -- confirm
        return self.DESCRIPTION

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the vectorizer's name, description, feature ID and shape."""
        super().write(depth, writer)
        self._write_line(f'description: {self.description}', depth, writer)
        self._write_line(f'feature_id: {self.feature_id}', depth, writer)
        self._write_line(f'shape: {self.shape}', depth, writer)

    def __str__(self):
        return (f'{self.feature_id} ({type(self)}), ' +
                f'desc={self.description}, shape: {self.shape}')

    def __repr__(self):
        return f'{self.__class__}: {self.__str__()}'
@dataclass
class FeatureContext(PersistableContainer):
    """Data created by coding and meant to be pickled on the file system.

    :see EncodableFeatureVectorizer.encode:

    """
    feature_id: str = field()
    """The feature id of the :class:`.FeatureVectorizer` that created this
    context."""

    def __str__(self):
        # e.g. "TensorFeatureContext (my_feature_id)"
        return '{} ({})'.format(type(self).__name__, self.feature_id)
@dataclass
class NullFeatureContext(FeatureContext):
    """A no-op feature context used for cases such as prediction batches with
    data points that have no labels.

    :see: :meth:`~zensols.deeplearn.batch.BatchStash.create_prediction`

    :see: :class:`~zensols.deeplearn.batch.Batch`

    """
@dataclass
class TensorFeatureContext(FeatureContext):
    """A context that encodes data directly to a tensor.  The tensor might be
    encoded as a sparse matrix, which becomes dense during the decoding
    process.

    """
    tensor: Tensor = field()
    """The output tensor of the encoding phase."""

    def deallocate(self):
        """Release the held tensor so its memory can be reclaimed."""
        super().deallocate()
        if hasattr(self, 'tensor'):
            del self.tensor

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the shape of the held tensor."""
        super().write(depth, writer)
        self._write_line(f'tensor shape: {self.tensor.shape}', depth, writer)

    def __str__(self):
        if self.tensor is None:
            tstr = '<none>'
        else:
            tstr = f'{self.tensor.shape}'
        return f'{super().__str__()}: {tstr}'

    def __repr__(self):
        return self.__str__()
@dataclass
class SparseTensorFeatureContext(FeatureContext):
    """Contains data that was encoded from a dense matrix as a sparse matrix
    and back.  Using torch sparse matrices currently leads to deadlocking in
    child processes, so scipy :class:`~scipy.sparse.csr_matrix` is used
    instead.

    """
    USE_SPARSE: ClassVar[bool] = True
    """Whether or not to enable sparse matrix serialization.  Otherwise, torch
    tensors (no conversion) are used."""

    sparse_data: Union[Tuple[Tuple[csr_matrix, int]], Tensor] = field()
    """The sparse array data."""

    @property
    def sparse_arr(self) -> Union[Tuple[csr_matrix], Tensor]:
        """The sparse matrices created by :meth:`to_sparse`, or the tensor
        itself when sparse serialization was disabled (:obj:`USE_SPARSE` is
        ``False``).

        """
        # bug fix: when USE_SPARSE is False, ``sparse_data`` is the tensor
        # itself; the previous ``self.sparse_data[0]`` returned only the
        # tensor's first row and then failed the tuple assertion below
        if isinstance(self.sparse_data, Tensor):
            return self.sparse_data
        assert isinstance(self.sparse_data[0], tuple)
        return self.sparse_data[0]

    @classmethod
    def to_sparse(cls, arr: Tensor) -> Tuple[Tuple[csr_matrix], int]:
        """Convert a tensor of at most three dimensions to scipy CSR matrices.

        :param arr: the (CPU) tensor to convert

        :return: a tuple of ``(csr matrices, original dimension count)``

        :raises VectorizerError: if ``arr`` has more than three dimensions

        """
        narr = arr.numpy()
        tdim = len(arr.shape)
        if tdim == 3:
            # one 2-D matrix per "page" along the first dimension
            narrs = tuple(map(lambda i: narr[i], range(narr.shape[0])))
        elif tdim == 2 or tdim == 1:
            narrs = (narr,)
        else:
            raise VectorizerError('Tensors of dimensions higher than ' +
                                  f'3 not supported: {arr.shape}')
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'creating sparse matrix: {arr.shape}')
        mats = tuple(map(lambda m: sparse.csr_matrix(m), narrs))
        return (mats, tdim)

    @classmethod
    def instance(cls, feature_id: str, arr: Tensor,
                 torch_config: TorchConfig):
        """Create an instance from a tensor, converting to sparse form when
        :obj:`USE_SPARSE` is enabled.

        :param feature_id: the feature id of the creating vectorizer

        :param arr: the tensor to encode

        :param torch_config: the device configuration (currently unused here)

        :raises VectorizerError: for half precision tensors, which scipy's
                                 sparse conversion does not support

        """
        arr = arr.cpu()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'encoding in to sparse tensor: {arr.shape}')
        if cls.USE_SPARSE:
            if arr.dtype == torch.float16:
                raise VectorizerError(
                    ('Type float 16 is not supported (scipy 1.9.3); set ' +
                     'SparseTensorFeatureContext.USE_SPARSE to False or ' +
                     'deeplearn_default.fp_precision higher'))
            with warnings.catch_warnings():
                # silence the numpy warnings in the scipy package
                #
                # scipy/sparse/_sputils.py:44: DeprecationWarning:
                # np.find_common_type is deprecated. Please use
                # `np.result_type` or `np.promote_types`.
                warnings.filterwarnings(
                    'ignore',
                    message=r"^np.find_common_type is deprecated",
                    category=DeprecationWarning)
                sarr = cls.to_sparse(arr)
        else:
            sarr = arr
        return cls(feature_id, sarr)

    def to_tensor(self, torch_config: TorchConfig) -> Tensor:
        """Decode back to a dense tensor with the originally encoded shape.

        :param torch_config: the device configuration (currently unused here)

        """
        if isinstance(self.sparse_arr, Tensor):
            # non-sparse path: the payload is already a tensor
            tarr = self.sparse_arr
        else:
            narr: Tuple[csr_matrix]
            tdim: int
            narr, tdim = self.sparse_data
            narrs = tuple(map(lambda sm: torch.from_numpy(sm.todense()), narr))
            if len(narrs) == 1:
                tarr = narrs[0]
            else:
                tarr = torch.stack(narrs)
            # csr matrices are always 2-D, so squeeze/unsqueeze leading
            # dimensions to restore the dimension count recorded at encode
            dim_diff = len(tarr.shape) - tdim
            if dim_diff > 0:
                for _ in range(dim_diff):
                    tarr = tarr.squeeze(0)
            elif dim_diff < 0:
                for _ in range(-dim_diff):
                    tarr = tarr.unsqueeze(0)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'decoded sparce matrix to: {tarr.shape}')
        return tarr
@dataclass
class MultiFeatureContext(FeatureContext):
    """A composite context that contains a tuple of other contexts.

    """
    contexts: Tuple[FeatureContext] = field()
    """The subordinate contexts."""

    @property
    def is_empty(self) -> bool:
        """Whether every subordinate context is a :class:`.NullFeatureContext`
        (vacuously ``True`` when there are no contexts).

        """
        return all(isinstance(ctx, NullFeatureContext)
                   for ctx in self.contexts)

    def deallocate(self):
        """Deallocate the subordinate contexts and drop the reference."""
        super().deallocate()
        if hasattr(self, 'contexts'):
            self._try_deallocate(self.contexts)
            del self.contexts