Source code for zensols.deeplearn.layer.recur

"""This file contains a convenience wrapper around RNN, GRU and LSTM modules in
PyTorch.

"""
__author__ = 'Paul Landes'

from typing import Union, Tuple
from dataclasses import dataclass, field
import logging
import torch
from torch import nn
from torch import Tensor
from zensols.config import ClassImporter
from zensols.deeplearn import DropoutNetworkSettings
from zensols.deeplearn.model import BaseNetworkModule
from . import LayerError

logger = logging.getLogger(__name__)


@dataclass
class RecurrentAggregationNetworkSettings(DropoutNetworkSettings):
    """Settings for a recurrent neural network.  This configures a
    :class:`.RecurrentAggregation` layer.

    """
    network_type: str = field()
    """One of ``rnn``, ``lstm`` or ``gru``."""

    aggregation: str = field()
    """A convenience operation to aggregate the parameters; this is one of:

        ``max``: return the max of the output states

        ``ave``: return the average of the output states

        ``last``: return the last output state

        ``none``: do not apply an aggregation function.

    """

    bidirectional: bool = field()
    """Whether or not the network is bidirectional."""

    input_size: int = field()
    """The input size to the network."""

    hidden_size: int = field()
    """The size of the hidden states of the network."""

    num_layers: int = field()
    """The number of *"stacked"* layers."""

    def get_module_class_name(self) -> str:
        return __name__ + '.RecurrentAggregation'
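# A minimal, hypothetical illustration (not part of the original module) of
# what the ``aggregation`` options above compute on a recurrent output tensor
# of shape ``(batch, sequence, hidden)``; the shapes are assumed values used
# only for the example:
#
#     out = torch.rand(4, 10, 8)          # (batch, sequence, hidden)
#     torch.max(out, dim=1)[0].shape      # 'max'  -> torch.Size([4, 8])
#     torch.mean(out, dim=1).shape        # 'ave'  -> torch.Size([4, 8])
#     out[:, -1, :].shape                 # 'last' -> torch.Size([4, 8])
#     out.shape                           # 'none' -> torch.Size([4, 10, 8])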
class RecurrentAggregation(BaseNetworkModule):
    """A recurrent neural network model with an output aggregation.  This
    includes RNNs, LSTMs and GRUs.

    """
    MODULE_NAME = 'recur'
    def __init__(self, net_settings: RecurrentAggregationNetworkSettings,
                 sub_logger: logging.Logger = None):
        """Initialize the recurrent layer.

        :param net_settings: the recurrent layer configuration

        :param sub_logger: the logger to use for the forward process in this
                           layer

        """
        super().__init__(net_settings, sub_logger)
        ns = net_settings
        if self.logger.isEnabledFor(logging.INFO):
            self.logger.info(f'creating {ns.network_type} network')
        class_name = f'torch.nn.{ns.network_type.upper()}'
        ci = ClassImporter(class_name, reload=False)
        hidden_size = ns.hidden_size // (2 if ns.bidirectional else 1)
        param = {'input_size': ns.input_size,
                 'hidden_size': hidden_size,
                 'num_layers': ns.num_layers,
                 'bidirectional': ns.bidirectional,
                 'batch_first': True}
        if ns.num_layers > 1 and ns.dropout is not None:
            # UserWarning: dropout option adds dropout after all but last
            # recurrent layer, so non-zero dropout expects num_layers greater
            # than 1
            param['dropout'] = ns.dropout
        self.rnn: nn.RNNBase = ci.instance(**param)
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f'created {type(self.rnn)} with {param}')
    def deallocate(self):
        super().deallocate()
        if hasattr(self, 'rnn'):
            del self.rnn

    @property
    def out_features(self) -> int:
        """The number of features output from all layers of this module.

        """
        ns = self.net_settings
        return ns.hidden_size

    def _forward(self, x: Tensor, x_init: Tensor = None) -> \
            Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        self._shape_debug('input', x)
        if x_init is None:
            x, hidden = self.rnn(x)
        else:
            x, hidden = self.rnn(x, x_init)
        self._shape_debug('recur', x)
        agg = self.net_settings.aggregation
        if agg == 'max':
            x = torch.max(x, dim=1)[0]
            self._shape_debug('max out shape', x)
        elif agg == 'ave':
            x = torch.mean(x, dim=1)
            self._shape_debug('ave out shape', x)
        elif agg == 'last':
            x = x[:, -1, :]
            self._shape_debug('last out shape', x)
        elif agg == 'none':
            pass
        else:
            raise LayerError(f'Unknown aggregate function: {agg}')
        self._shape_debug('aggregation', x)
        return x, hidden
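# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): the guarded block below
# shows, in plain PyTorch rather than through the zensols classes, the
# computation this wrapper builds for a bidirectional LSTM with ``max``
# aggregation.  All sizes and tensor shapes are illustrative assumptions, not
# values taken from the library.
if __name__ == '__main__':
    batch_size, seq_len, input_size, hidden_size = 4, 10, 16, 8
    # the wrapper halves the hidden size when bidirectional so that the two
    # concatenated directions sum back to ``hidden_size``
    demo_rnn = nn.LSTM(input_size=input_size,
                       hidden_size=hidden_size // 2,
                       num_layers=1,
                       bidirectional=True,
                       batch_first=True)
    demo_x = torch.rand(batch_size, seq_len, input_size)
    demo_out, demo_hidden = demo_rnn(demo_x)
    # ``max`` aggregation: take the max over the sequence (time) dimension
    demo_agg = torch.max(demo_out, dim=1)[0]
    print(demo_out.shape)   # torch.Size([4, 10, 8])
    print(demo_agg.shape)   # torch.Size([4, 8])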