"""This file contains a convenience wrapper around RNN, GRU and LSTM modules inPyTorch."""__author__='Paul Landes'fromtypingimportUnion,Tuplefromdataclassesimportdataclass,fieldimportloggingimporttorchfromtorchimportnnfromtorchimportTensorfromzensols.configimportClassImporterfromzensols.deeplearnimportDropoutNetworkSettingsfromzensols.deeplearn.modelimportBaseNetworkModulefrom.importLayerErrorlogger=logging.getLogger(__name__)
@dataclass
class RecurrentAggregationNetworkSettings(DropoutNetworkSettings):
    """Settings for a recurrent neural network.  This configures a
    :class:`.RecurrentAggregation` layer.

    """
    network_type: str = field()
    """One of ``rnn``, ``lstm`` or ``gru``."""

    aggregation: str = field()
    """A convenience operation to aggregate the parameters; this is one of:

        ``max``: return the max of the output states

        ``ave``: return the average of the output states

        ``last``: return the last output state

        ``none``: do not apply an aggregation function.

    """

    bidirectional: bool = field()
    """Whether or not the network is bidirectional."""

    input_size: int = field()
    """The input size to the network."""

    hidden_size: int = field()
    """The size of the hidden states of the network."""

    num_layers: int = field()
    """The number of *"stacked"* layers."""
    def __str__(self) -> str:
        return (f'linear: {self.input_size} -> ' +
                f'{self.hidden_size} X {self.num_layers}')

class RecurrentAggregation(BaseNetworkModule):
    """A recurrent neural network model with an output aggregation.  This
    includes RNNs, LSTMs and GRUs.

    """
    MODULE_NAME = 'recur'
    def __init__(self, net_settings: RecurrentAggregationNetworkSettings,
                 sub_logger: logging.Logger = None):
        """Initialize the recurrent layer.

        :param net_settings: the recurrent layer configuration

        :param sub_logger: the logger to use for the forward process in this
                           layer

        """
        super().__init__(net_settings, sub_logger)
        ns = net_settings
        if self.logger.isEnabledFor(logging.INFO):
            self.logger.info(f'creating {ns.network_type} network')
        class_name = f'torch.nn.{ns.network_type.upper()}'
        ci = ClassImporter(class_name, reload=False)
        # halve the hidden size for bidirectional networks so the concatenated
        # forward/backward output matches ``ns.hidden_size``
        hidden_size = ns.hidden_size // (2 if ns.bidirectional else 1)
        param = {'input_size': ns.input_size,
                 'hidden_size': hidden_size,
                 'num_layers': ns.num_layers,
                 'bidirectional': ns.bidirectional,
                 'batch_first': True}
        if ns.num_layers > 1 and ns.dropout is not None:
            # UserWarning: dropout option adds dropout after all but last
            # recurrent layer, so non-zero dropout expects num_layers greater
            # than 1
            param['dropout'] = ns.dropout
        self.rnn: nn.RNNBase = ci.instance(**param)
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f'created {type(self.rnn)} with {param}')
    @property
    def out_features(self) -> int:
        """The number of features output from all layers of this module.

        """
        ns = self.net_settings
        return ns.hidden_size

    def _forward(self, x: Tensor, x_init: Tensor = None) -> \
            Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        self._shape_debug('input', x)
        if x_init is None:
            x, hidden = self.rnn(x)
        else:
            x, hidden = self.rnn(x, x_init)
        self._shape_debug('recur', x)
        agg = self.net_settings.aggregation
        if agg == 'max':
            x = torch.max(x, dim=1)[0]
            self._shape_debug('max out shape', x)
        elif agg == 'ave':
            x = torch.mean(x, dim=1)
            self._shape_debug('ave out shape', x)
        elif agg == 'last':
            x = x[:, -1, :]
            self._shape_debug('last out shape', x)
        elif agg == 'none':
            pass
        else:
            raise LayerError(f'Unknown aggregate function: {agg}')
        self._shape_debug('aggregation', x)
        return x, hidden
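

# The block below is a usage sketch, not part of the original module: it shows
# what the ``max`` aggregation in ``_forward`` computes using a plain
# ``torch.nn.LSTM`` in place of the ``ClassImporter``-created instance.  The
# sizes (batch of 4, sequence length of 10, 8 input features, hidden size of
# 16 per direction) are illustrative assumptions only.
if __name__ == '__main__':
    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
                   bidirectional=True, batch_first=True)
    x = torch.randn(4, 10, 8)
    # ``out`` has shape (4, 10, 32): the two directions' hidden states are
    # concatenated along the feature dimension
    out, _ = lstm(x)
    # the ``max`` aggregation reduces over the sequence dimension, yielding a
    # (4, 32) tensor: one aggregated state per batch element
    agg = torch.max(out, dim=1)[0]
    print(out.shape, agg.shape)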