# Source code for zensols.deeplearn.model.module

"""Base class PyTorch module and utilities.

"""
__author__ = 'Paul Landes'

from types import ModuleType
from typing import Union, Type, ClassVar
from abc import abstractmethod, ABCMeta
import logging
import torch
from torch import nn
from torch import Tensor
from zensols.introspect import ClassImporter
from zensols.persist import PersistableContainer
from zensols.deeplearn import (
    ModelError,
    NetworkSettings,
    ActivationNetworkSettings,
    DropoutNetworkSettings,
    BatchNormNetworkSettings,
    EarlyBailError,
)
from zensols.deeplearn.batch import Batch

logger = logging.getLogger(__name__)


class DebugModule(nn.Module):
    """A utility base class that makes logging more understandable.

    Debug messages sent through :meth:`_debug` are prefixed with the class
    name and/or :obj:`MODULE_NAME` and truncated to a maximum length before
    they are handed to the logger.

    """
    DEBUG_DEVICE: ClassVar[bool] = False
    """If ``True``, add tensor devices to log messages."""

    DEBUG_TYPE: ClassVar[bool] = False
    """If ``True``, add tensor data types to log messages."""

    DEBUG_CLASS: ClassVar[bool] = True
    """If ``True``, add the logging class to log messages."""

    MODULE_NAME: ClassVar[str] = None
    """The module name used in the logging message.  This is set in each
    inherited class.

    """
    _DEBUG_MESSAGE_MAX_LEN: ClassVar[int] = 100

    def __init__(self, sub_logger: logging.Logger = None):
        """Initialize.

        :param sub_logger: used to log activity in this module so they logged
                           module comes from some parent model; when ``None``
                           the logger is resolved from the module that
                           defines this instance's class

        """
        super().__init__()
        if sub_logger is None:
            self.logger = self._resolve_class_logger()
        else:
            self.logger = sub_logger

    def _resolve_class_logger(self) -> logging.Logger:
        """Return the logger to use when none was given to the initializer.

        :return: the ``logger`` attribute of the module that defines this
                 instance's class if it is a :class:`logging.Logger`,
                 otherwise the logger of this module

        """
        mod: ModuleType = ClassImporter.get_module(
            self.__class__.__module__, False)
        # a single lookup with a default covers both a missing attribute and
        # an attribute that is not a logger
        lg_mod = getattr(mod, 'logger', None)
        return lg_mod if isinstance(lg_mod, logging.Logger) else logger

    def _debug(self, msg: str):
        """Debug a message using the module name in the description.

        Nothing is done unless debug logging is enabled for :obj:`logger`.

        :param msg: the message to log; anything past the maximum message
                    length is replaced with an ellipsis

        """
        if self.logger.isEnabledFor(logging.DEBUG):
            if msg is not None:
                if len(msg) > self._DEBUG_MESSAGE_MAX_LEN:
                    # leave room for the ellipsis so the result is exactly
                    # the maximum length
                    msg = msg[:self._DEBUG_MESSAGE_MAX_LEN - 3] + '...'
            mname = self.MODULE_NAME
            cls = self.__class__.__name__
            mname = '' if mname is None else f'[{mname}]'
            if self.DEBUG_CLASS:
                prefix = f'{cls}{mname}'
            else:
                # without the class name, fall back to it only when there is
                # no module name to identify the source
                prefix = mname if len(mname) > 0 else f'[{cls}]'
            self.logger.debug(f'{prefix} {msg}')

    def _shape_debug(self, msg: str, x: Tensor):
        """Debug a message using the module name in the description and
        include the shape.

        The device and data type are appended when :obj:`DEBUG_DEVICE` and
        :obj:`DEBUG_TYPE` are set respectively.

        :param msg: the message to log

        :param x: the tensor to describe; if ``None``, the shape, device and
                  type are all reported as ``None``

        """
        if self.logger.isEnabledFor(logging.DEBUG):
            if x is None:
                shape, device, dtype = [None] * 3
            else:
                shape, device, dtype = x.shape, x.device, x.dtype
            msg = f'{msg} shape: {shape}'
            if self.DEBUG_DEVICE:
                msg += f', device: {device}'
            if self.DEBUG_TYPE:
                msg += f', type: {dtype}'
            self._debug(msg)

    def _bail(self):
        """A convenience method to assist in debugging.  This is useful when
        the output isn't in the correct form for the :class:`.ModelExecutor`.

        :raises EarlyBailError: always, after logging a separator line

        """
        self.logger.debug('-' * 60)
        raise EarlyBailError()
class BaseNetworkModule(DebugModule, PersistableContainer, metaclass=ABCMeta):
    """A utility base network module that contains ubiquitous, but optional
    layers, such as dropout and batch layers, activation, etc.

    .. document private functions
    .. automethod:: _forward

    """
    def __init__(self, net_settings: NetworkSettings,
                 sub_logger: logging.Logger = None):
        """Initialize.

        :param net_settings: contains common layers such as dropout and batch
                             normalization

        :param sub_logger: used to log activity in this module so they logged
                           module comes from some parent model

        :raises ModelError: if only one of the batch normalization dimension
                            and features is set

        """
        super().__init__(sub_logger)
        self.net_settings = ns = net_settings
        # each optional layer is taken from the settings only when the
        # settings are of the type that provides it, otherwise it is unset
        if isinstance(ns, DropoutNetworkSettings):
            self.dropout = ns.dropout_layer
        else:
            self.dropout = None
        if isinstance(ns, BatchNormNetworkSettings) and \
           (ns.batch_norm_d is not None or
                ns.batch_norm_features is not None):
            # batch normalization needs both values; one without the other is
            # a configuration error
            if ns.batch_norm_d is None or ns.batch_norm_features is None:
                raise ModelError('Both the dimension and features must be ' +
                                 f'set if one is set: {ns}')
            self.batch_norm = ns.batch_norm_layer
        else:
            self.batch_norm = None
        if isinstance(ns, ActivationNetworkSettings):
            self.activation_function = ns.activation_function
        else:
            self.activation_function = None

    def _deallocate_children_modules(self):
        """Attempt to deallocate each direct child module of this module."""
        for layer in self.children():
            self._try_deallocate(layer)

    def __getstate__(self):
        """Disallow pickling of layers.

        :raises ModelError: always

        """
        raise ModelError(f'Layers should not be pickeled: {self}')

    @abstractmethod
    def _forward(self, x: Union[Batch, Tensor], *args, **kwargs) -> Tensor:
        """The model's forward implementation.  Normal backward semantics are
        no different.

        :param x: the batch or tensor to train, validate or test on; the type
                  depends on the needs of the model

        :param args: additional model specific arguments needed by classes
                     that need more context

        :param kwargs: additional model specific arguments needed by classes
                       that need more context

        """
        pass

    @staticmethod
    def device_from_module(module: nn.Module) -> torch.device:
        """Return the device on which the model is configured.

        :param module: the module containing the parameters used to get the
                       device

        :return: the device of the first parameter of ``module``

        :raises StopIteration: if ``module`` has no parameters

        """
        return next(module.parameters()).device

    @property
    def device(self) -> torch.device:
        """Return the device on which the model is configured."""
        return self.device_from_module(self)

    def _forward_dropout(self, x: Tensor) -> Tensor:
        """Forward the dropout if there is one configured.

        :param x: the tensor to transform

        :return: the output of the dropout layer, or ``x`` unchanged if no
                 dropout is set

        """
        if self.dropout is None:
            self._debug('skipping unset dropout')
        else:
            # guard so the layer is only formatted when it will be logged
            if self.logger.isEnabledFor(logging.DEBUG):
                self._debug(f'dropout: {self.dropout}')
            x = self.dropout(x)
        return x

    def _forward_batch_norm(self, x: Tensor) -> Tensor:
        """Forward the batch normalization if there is one configured.

        :param x: the tensor to transform

        :return: the output of the batch normalization layer, or ``x``
                 unchanged if no batch normalization is set

        """
        if self.batch_norm is None:
            self._debug('skipping unset batch norm')
        else:
            if self.logger.isEnabledFor(logging.DEBUG):
                self._debug(f'batch norm: {self.batch_norm}')
            x = self.batch_norm(x)
        return x

    def _forward_activation(self, x: Tensor) -> Tensor:
        """Transform using the activation function if there is one
        configured.

        :param x: the tensor to transform

        :return: the output of the activation function, or ``x`` unchanged if
                 no activation function is set

        """
        if self.activation_function is None:
            # name the layer that is skipped, as the dropout and batch norm
            # counterparts do
            self._debug('skipping unset activation')
        else:
            if self.logger.isEnabledFor(logging.DEBUG):
                self._debug(f'activation: {self.activation_function}')
            x = self.activation_function(x)
        return x

    def _forward_batch_act_drop(self, x: Tensor) -> Tensor:
        """Forward batch normalization, activation and dropout, in that
        order, for those layers that are configured.

        :param x: the tensor to transform

        :return: the tensor after every configured layer has been applied

        :see: `Sunghean et al. <http://mipal.snu.ac.kr/images/1/16/Dropout_ACCV2016.pdf>`_

        :see: `Ioffe et al. <https://arxiv.org/pdf/1502.03167.pdf>`_

        """
        x = self._forward_batch_norm(x)
        x = self._forward_activation(x)
        x = self._forward_dropout(x)
        return x

    def forward(self, x: Union[Batch, Tensor], *args, **kwargs) -> Tensor:
        """Main forward takes a batch for top level modules, or a tensor for
        framework based layers.  Return the transformed tensor.

        :param x: the batch or tensor passed on to :meth:`_forward`

        :param args: passed on to :meth:`_forward`

        :param kwargs: passed on to :meth:`_forward`

        :return: the result of :meth:`_forward`

        """
        # only batches are logged here
        if self.logger.isEnabledFor(logging.DEBUG) and isinstance(x, Batch):
            self._debug(f'input batch: {x}')
        return self._forward(x, *args, **kwargs)