Source code for zensols.deeplearn.layer.linear

"""Convenience classes for linear layers.

"""
__author__ = 'Paul Landes'

from typing import Any, Tuple
from dataclasses import dataclass, field
import logging
import sys
import torch
from torch import nn
from zensols.deeplearn import (
    ActivationNetworkSettings,
    DropoutNetworkSettings,
    BatchNormNetworkSettings,
)
from zensols.deeplearn.model import BaseNetworkModule
from . import LayerError

logger = logging.getLogger(__name__)


@dataclass
class DeepLinearNetworkSettings(ActivationNetworkSettings,
                                DropoutNetworkSettings,
                                BatchNormNetworkSettings):
    """Settings for a deep fully connected network using :class:`.DeepLinear`.

    """
    ## TODO: centralize on either in_features or input_size:
    # embedding_output_size, RecurrentCRFNetworkSettings.input_size
    in_features: int = field()
    """The number of features to the first layer."""

    out_features: int = field()
    """The number of features as output from the last layer."""

    middle_features: Tuple[Any] = field()
    """The number of features in the middle layers; if ``proportions`` is
    ``True``, then each number is how much to grow or shrink as a percentage
    of the last layer, otherwise, it's the number of features.

    """

    proportions: bool = field()
    """Whether or not to interpret ``middle_features`` as a proportion of the
    previous layer or use directly as the size of the middle layer.

    """

    repeats: int = field()
    """The number of repeats of the :obj:`middle_features` configuration."""

    def get_module_class_name(self) -> str:
        return __name__ + '.DeepLinear'
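

# Illustrative sketch (not part of the original module): how ``in_features``,
# ``middle_features``, ``proportions``, ``repeats`` and ``out_features``
# determine the widths of the linear layers.  The values are made up; the
# loop mirrors the sizing logic in ``DeepLinear.__init__`` below.
def _example_layer_widths() -> list:
    in_features, out_features = 10, 3
    middle_features, proportions, repeats = (0.5,), True, 2
    widths = []
    last_feat = in_features
    for mf in middle_features:
        for _ in range(repeats):
            # grow/shrink as a proportion of the previous layer's width, or
            # use the value directly as the layer width
            next_feat = int(last_feat * mf) if proportions else int(mf)
            widths.append((last_feat, next_feat))
            last_feat = next_feat
    # the final layer maps to the configured output feature count
    widths.append((last_feat, out_features))
    # yields [(10, 5), (5, 2), (2, 3)]
    return widths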


class DeepLinear(BaseNetworkModule):
    """A layer that contains one or more nested layers, including batch
    normalization and activation.  The input and output layer shapes are given
    and an optional 0 or more middle layers are given as percent changes in
    size or exact numbers.

    If the network settings are configured to have batch normalization, batch
    normalization layers are added after each linear layer.

    The dropout and activation function (if any) are applied in between each
    layer, allowing other dropouts and activation functions to be applied
    before and after.  Note that the activation is implemented as a function,
    and not a layer.

    For example, if batch normalization and an activation function are
    configured and two layers are configured, the network is configured as:

      1. linear
      2. batch normalization
      3. activation
      4. dropout
      5. linear
      6. batch normalization
      7. activation
      8. dropout

    The module also provides the output features of each layer with
    :py:meth:`n_features_after_layer` and the ability to forward through only
    the first given set of layers with :meth:`forward_n_layers`.

    """
    MODULE_NAME = 'linear'

    def __init__(self, net_settings: DeepLinearNetworkSettings,
                 sub_logger: logging.Logger = None):
        """Initialize the deep linear layer.

        :param net_settings: the deep linear layer configuration

        :param sub_logger: the logger to use for the forward process in this
                           layer

        """
        super().__init__(net_settings, sub_logger)
        ns = net_settings
        last_feat = ns.in_features
        lin_layers = []
        bnorm_layers = []
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'in: {ns.in_features}, ' +
                        f'middle: {ns.middle_features}, ' +
                        f'out: {ns.out_features}')
        for mf in ns.middle_features:
            for i in range(ns.repeats):
                if ns.proportions:
                    next_feat = int(last_feat * mf)
                else:
                    next_feat = int(mf)
                self._add_layer(last_feat, next_feat, ns.dropout,
                                lin_layers, bnorm_layers)
                last_feat = next_feat
        if ns.out_features is not None:
            self._add_layer(last_feat, ns.out_features, ns.dropout,
                            lin_layers, bnorm_layers)
        self.lin_layers = nn.Sequential(*lin_layers)
        if len(bnorm_layers) > 0:
            self.bnorm_layers = nn.Sequential(*bnorm_layers)
        else:
            self.bnorm_layers = None

    def deallocate(self):
        super().deallocate()
        if hasattr(self, 'lin_layers'):
            del self.lin_layers

    def _add_layer(self, in_features: int, out_features: int, dropout: float,
                   lin_layers: list, bnorm_layers):
        ns = self.net_settings
        n_layer = len(lin_layers)
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'add {n_layer}: in={in_features} out={out_features}')
        lin_layer = nn.Linear(in_features, out_features)
        lin_layers.append(lin_layer)
        if ns.batch_norm_d is not None:
            if out_features is None:
                raise LayerError('Bad out features')
            if ns.batch_norm_features is None:
                bn_features = out_features
            else:
                bn_features = ns.batch_norm_features
            layer = ns.create_batch_norm_layer(ns.batch_norm_d, bn_features)
            if layer is None:
                raise LayerError(f'Bad layer params: D={ns.batch_norm_d}, ' +
                                 f'features={out_features}')
            bnorm_layers.append(layer)

    def get_linear_layers(self) -> Tuple[nn.Module]:
        """Return all linear layers.

        """
        return tuple(self.lin_layers)

    def get_batch_norm_layers(self) -> Tuple[nn.Module]:
        """Return all batch normalization layers.

        """
        if self.bnorm_layers is not None:
            return tuple(self.bnorm_layers)

    @property
    def out_features(self) -> int:
        """The number of features output from all layers of this module.

        """
        n_layers = len(self.get_linear_layers())
        return self.n_features_after_layer(n_layers - 1)

    def n_features_after_layer(self, nth_layer) -> int:
        """Get the output features of the Nth (0 index based) layer.

        :param nth_layer: the layer to use for getting the output features

        """
        return self.get_linear_layers()[nth_layer].out_features

    def forward_n_layers(self, x: torch.Tensor, n_layers: int,
                         full_forward: bool = False) -> torch.Tensor:
        """Forward through the first N (0-index based) layers.

        :param n_layers: the number of layers to forward through (0-based
                         index)

        :param full_forward: if ``True``, also return the full forward as a
                             second parameter

        :return: the tensor output of all layers or a tuple of ``(N-th layer,
                 all layers)``

        """
        return self._forward(x, n_layers, full_forward)
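
    # Illustrative usage (not part of the original source): assuming
    # ``module`` is a DeepLinear instance with at least two linear layers and
    # ``x`` is a correctly shaped input tensor, the following returns the
    # output after the second layer (index 1) along with the output of the
    # full stack:
    #
    #     second, full = module.forward_n_layers(x, 1, full_forward=True)
    #     assert second.shape[-1] == module.n_features_after_layer(1)
    #     assert full.shape[-1] == module.out_features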

    def _forward(self, x: torch.Tensor, n_layers: int = sys.maxsize,
                 full_forward: bool = False) -> torch.Tensor:
        lin_layers = self.get_linear_layers()
        bnorm_layers = self.get_batch_norm_layers()
        n_layers = min(len(lin_layers) - 1, n_layers)
        x_ret = None
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'linear: num layers: {len(lin_layers)}')
            self._shape_debug('input', x)
        for i, layer in enumerate(lin_layers):
            x = layer(x)
            self._shape_debug('linear', x)
            if bnorm_layers is not None:
                blayer = bnorm_layers[i]
                if self.logger.isEnabledFor(logging.DEBUG):
                    self._debug(f'batch norm: {blayer}')
                x = blayer(x)
            x = self._forward_activation(x)
            x = self._forward_dropout(x)
            if i == n_layers:
                x_ret = x
                if self.logger.isEnabledFor(logging.DEBUG):
                    self._debug(f'reached {i}th layer = n_layers')
                    self._shape_debug('x_ret', x_ret)
                if not full_forward:
                    self._debug('breaking')
                    break
        if full_forward:
            return x_ret, x
        else:
            return x
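

# Illustrative sketch (not part of the original module): a plain PyTorch
# stack equivalent to the two layer, batch normalized configuration described
# in the DeepLinear class docstring above.  Feature sizes, the activation
# function and the dropout rate are made up for the example; nn.BatchNorm1d
# assumes a 2-D (batch, feature) input.
def _example_equivalent_stack() -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(10, 5),    # 1. linear
        nn.BatchNorm1d(5),   # 2. batch normalization
        nn.ReLU(),           # 3. activation
        nn.Dropout(0.1),     # 4. dropout
        nn.Linear(5, 3),     # 5. linear
        nn.BatchNorm1d(3),   # 6. batch normalization
        nn.ReLU(),           # 7. activation
        nn.Dropout(0.1),     # 8. dropout
    )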