"""Contains classes that make up a text classification model.
"""
__author__ = 'Paul Landes'
from typing import Any
from dataclasses import dataclass, field
import logging
import torch
from zensols.deeplearn import DropoutNetworkSettings
from zensols.deeplearn.batch import Batch
from zensols.deeplearn.layer import (
    DeepLinear,
    DeepLinearNetworkSettings,
    RecurrentAggregation,
    RecurrentAggregationNetworkSettings,
)
from zensols.deepnlp.layer import (
    EmbeddingNetworkSettings,
    EmbeddingNetworkModule,
)
logger = logging.getLogger(__name__)
@dataclass
class ClassifyNetworkSettings(DropoutNetworkSettings, EmbeddingNetworkSettings):
"""A utility container settings class for convulsion network models. This
class also updates the recurrent network's drop out settings when changed.
"""
    recurrent_settings: RecurrentAggregationNetworkSettings = field()
    """Contains the configuration for the model's RNN."""

    linear_settings: DeepLinearNetworkSettings = field()
    """Contains the configuration for the model's FF *decoder*."""
    def _set_option(self, name: str, value: Any):
        super()._set_option(name, value)
        # propagate dropout changes to the nested recurrent and linear settings
        if name == 'dropout' and hasattr(self, 'recurrent_settings'):
            if self.recurrent_settings is not None:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'setting dropout: {value}')
                self.recurrent_settings.dropout = value
            self.linear_settings.dropout = value
    def get_module_class_name(self) -> str:
        return __name__ + '.ClassifyNetwork'
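# The sketch below (illustrative only, not part of the library) shows the
# intended effect of the dropout propagation implemented in ``_set_option``
# above.  It assumes that assigning to ``dropout`` on the settings instance is
# routed through ``_set_option`` by ``DropoutNetworkSettings``, and that
# ``my_recurrent_settings`` and ``my_linear_settings`` are previously
# constructed (hypothetical) settings instances:
#
#     settings = ClassifyNetworkSettings(
#         recurrent_settings=my_recurrent_settings,
#         linear_settings=my_linear_settings,
#         ...)
#     settings.dropout = 0.2
#     assert settings.recurrent_settings.dropout == 0.2
#     assert settings.linear_settings.dropout == 0.2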
class ClassifyNetwork(EmbeddingNetworkModule):
"""A model that either allows for an RNN or a BERT transforemr to classify
text.
"""
MODULE_NAME = 'classify'
    def __init__(self, net_settings: ClassifyNetworkSettings):
        super().__init__(net_settings, logger)
        ns = self.net_settings
        rs = ns.recurrent_settings
        ls = ns.linear_settings
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'embedding output size: {self.embedding_output_size}')
        if rs is None:
            # without a recurrent layer, the linear decoder consumes the
            # embedding output and any joined features directly
            ln_in_features = self.embedding_output_size + self.join_size
            self.recur = None
        else:
            # the recurrent layer consumes the embedding output; its
            # aggregated output is joined before the linear decoder
            rs.input_size = self.embedding_output_size
            self._debug(f'recur settings: {rs}')
            self.recur = RecurrentAggregation(rs)
            self._debug(f'embedding join size: {self.join_size}')
            self.join_size += self.recur.out_features
            self._debug(f'after lstm join size: {self.join_size}')
            ln_in_features = self.join_size
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'linear in size: {ln_in_features}')
        ls.in_features = ln_in_features
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'linear input settings: {ls}')
        self.fc_deep = DeepLinear(ls, self.logger)
    def _forward(self, batch: Batch) -> torch.Tensor:
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug('review batch:')
            batch.write()
        x = self.forward_embedding_features(batch)
        self._shape_debug('embedding', x)
        x = self.forward_token_features(batch, x)
        self._shape_debug('token', x)
        if self.recur is not None:
            # keep only the aggregated hidden output of the recurrent layer
            x = self.recur(x)[0]
            self._shape_debug('lstm', x)
        x = self.forward_document_features(batch, x)
        x = self.fc_deep(x)
        self._shape_debug('deep linear', x)
        return x
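# A minimal sketch (assumptions noted) of how this module is meant to be
# instantiated: the framework is expected to resolve the module class from
# ``ClassifyNetworkSettings.get_module_class_name`` and construct it with the
# settings, after which a prepared ``Batch`` flows through the forward pass:
#
#     cls_name = settings.get_module_class_name()  # '<this module>.ClassifyNetwork'
#     net = ClassifyNetwork(settings)
#     logits = net(batch)  # assumes the framework's forward() delegates to _forward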