Source code for zensols.deepnlp.layer.embrecurcrf

"""Embedding input layer classes.

"""
__author__ = 'Paul Landes'

from typing import Tuple, List, Optional, Union
from dataclasses import dataclass, field
import logging
import torch
from torch import Tensor
from zensols.deeplearn import ModelError, DatasetSplitType
from zensols.deeplearn.model import (
    SequenceNetworkModule, SequenceNetworkContext, SequenceNetworkOutput
)
from zensols.deeplearn.batch import Batch
from zensols.deeplearn.layer import (
    RecurrentCRFNetworkSettings,
    RecurrentCRF,
)
from zensols.deepnlp.layer import (
    EmbeddingNetworkSettings,
    EmbeddingNetworkModule,
)

# module level logger named after this module; the network module class in
# this file logs through its own ``self.logger`` instead
logger = logging.getLogger(__name__)


@dataclass
class EmbeddedRecurrentCRFSettings(EmbeddingNetworkSettings):
    """Settings used to configure and create
    :class:`.EmbeddedRecurrentCRF`: an embedding input layer followed by a
    recurrent network with an (optional) CRF output layer.

    """
    recurrent_crf_settings: RecurrentCRFNetworkSettings = field()
    """The settings of the recurrent CRF layer (configure this with an LSTM
    to get a (Bi)LSTM CRF).

    """

    mask_attribute: str = field()
    """The name of the vectorizer attribute that has the mask feature."""

    tensor_predictions: bool = field(default=False)
    """Whether or not predictions are returned as tensors.

    No use case is currently identified for setting this to ``True`` and
    doing so inflates performance metrics: the batch iterator creates one
    tensor for the entire batch, which adds many default padded values that
    are then counted as results.

    """

    use_crf: bool = field(default=True)
    """Whether to use the CRF output layer.  This value is passed on when
    the recurrent CRF module is created and selects the forward path taken
    by :class:`.EmbeddedRecurrentCRF`.

    """

    def get_module_class_name(self) -> str:
        """Return the fully qualified name of the module class configured by
        these settings.

        """
        return f'{__name__}.EmbeddedRecurrentCRF'
class EmbeddedRecurrentCRF(EmbeddingNetworkModule, SequenceNetworkModule):
    """A recurrent neural network composed of an embedding input, a recurrent
    network, and a linear conditional random field output layer.  When
    configured with an LSTM, this becomes a (Bi)LSTM-CRF.  More specifically,
    this network has the following:

      1. Input embeddings mapped from tokens.

      2. Recurrent network (i.e. LSTM).

      3. Fully connected feed forward deep linear layer(s) as the decoder.

      4. Linear chain conditional random field (CRF) layer.

      5. Output the labels.

    """
    MODULE_NAME = 'emb-recur-crf'

    def __init__(self, net_settings: EmbeddedRecurrentCRFSettings,
                 sub_logger: Optional[logging.Logger] = None):
        """Initialize the embedding layer and create the recurrent CRF.

        :param net_settings: the settings used to configure this module

        :param sub_logger: passed to the super class and to the created
                           recurrent CRF module

        """
        super().__init__(net_settings, sub_logger)
        ns = self.net_settings
        rc = ns.recurrent_crf_settings
        # the recurrent layer consumes the output of the embedding layer
        rc.input_size = self.embedding_output_size
        self.mask_attribute = ns.mask_attribute
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f'recur emb settings: {rc}')
        self.recurcrf: RecurrentCRF = rc.create_module(
            sub_logger=sub_logger, use_crf=ns.use_crf)

    def deallocate(self):
        """Deallocate this module and the recurrent CRF it owns."""
        super().deallocate()
        self.recurcrf.deallocate()

    def _get_mask(self, batch: Batch) -> Tensor:
        """Return the mask stored in ``batch`` under
        :obj:`mask_attribute`.

        """
        mask = batch[self.mask_attribute]
        self._shape_debug('mask', mask)
        return mask

    def _forward_train_with_crf(self, batch: Batch) -> Tensor:
        """Embed the batch and return the output of the recurrent CRF
        ``forward`` given the mask and labels, which :meth:`_forward` uses as
        the training loss.

        """
        labels = batch.get_labels()
        self._shape_debug('labels', labels)
        mask = self._get_mask(batch)
        self._shape_debug('mask', mask)
        x = super()._forward(batch)
        self._shape_debug('super emb', x)
        x = self.recurcrf.forward(x, mask, labels)
        self._shape_debug('recur', x)
        return x

    def _forward_train_no_crf(self, batch: Batch,
                              context: SequenceNetworkContext) -> \
            Tuple[List[List[int]], Tensor]:
        """Compute predictions and the loss without the CRF layer.

        The logits of the recurrent decoder are scored against the labels
        with ``context.criterion`` and the predictions are the arg max over
        the last logits dimension with masked positions removed.

        :param batch: the batch, which must have labels

        :param context: provides the criterion used to compute the loss

        :return: a tuple of the per sequence prediction lists and the loss

        :raises ModelError: if the batch has no labels

        """
        recur_crf: RecurrentCRF = self.recurcrf
        # no implementation yet for prediction sans-training
        labels: Optional[Tensor] = batch.get_labels()
        if labels is None:
            # fail with a clear message rather than an attribute error on
            # ``None`` when the labels are flattened below
            raise ModelError(
                'Labels are required when not using a CRF; ' +
                'prediction without labels is not implemented')
        emb: Tensor = EmbeddingNetworkModule._forward(self, batch)
        self._shape_debug('embedding', emb)
        logits: Tensor = recur_crf.forward_recur_decode(emb)
        self._shape_debug('logits', logits)
        # collapse the first two dimensions for the criterion
        logits_flat: Tensor = logits.flatten(0, 1)
        labels_flat: Tensor = labels.flatten(0, 1)
        self._shape_debug('flat logits', logits_flat)
        self._shape_debug('flat labels', labels_flat)
        loss: Tensor = context.criterion(logits_flat, labels_flat)
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'loss: {loss}')
        pred_labels: Tensor = logits.argmax(dim=2)
        self._shape_debug('predictions (agg)', pred_labels)
        mask: Tensor = self._get_mask(batch)
        assert len(mask.size()) == 2
        # keep only the unmasked predictions of each row of the batch
        pred_lsts: List[List[int]] = [
            torch.masked_select(row_preds, row_mask).tolist()
            for row_preds, row_mask in zip(pred_labels, mask)]
        return pred_lsts, loss

    def _decode(self, batch: Batch, add_loss: bool) -> \
            Tuple[List[List[int]], Optional[Tensor], Tensor]:
        """Decode the batch with the recurrent CRF.

        :param batch: the batch to decode

        :param add_loss: whether to also compute the loss with the recurrent
                         CRF ``forward`` using the batch's labels

        :return: a tuple of the decoded output, the loss (``None`` when
                 ``add_loss`` is ``False``) and the score given by the
                 recurrent CRF ``decode``

        """
        loss: Optional[Tensor] = None
        mask: Tensor = self._get_mask(batch)
        self._shape_debug('mask', mask)
        x: Tensor = super()._forward(batch)
        self._shape_debug('super emb', x)
        if add_loss:
            labels = batch.get_labels()
            loss = self.recurcrf.forward(x, mask, labels)
        x, score = self.recurcrf.decode(x, mask)
        self._debug(f'recur {len(x)}')
        self._shape_debug('score', score)
        return x, loss, score

    def _map_labels(self, batch: Batch, context: SequenceNetworkContext,
                    labels: Union[List[List[int]], Tensor]) -> \
            List[List[int]]:
        """Return ``labels`` unchanged.  :meth:`_forward` passes both the
        predictions and the gold labels through this method (presumably an
        extension point for sub classes).

        """
        return labels

    def _shape_or_list_debug(self, msg: str,
                             data: Union[List[List[int]], Tensor],
                             full: bool = False):
        """Log a description of ``data`` at the debug level: the shape of a
        tensor, or otherwise the length and element type.

        :param msg: the prefix of the logged message

        :param data: what to describe, which can be ``None``

        :param full: whether to also log the entire contents of ``data``

        """
        # skip all formatting work unless debug logging is enabled
        if not self.logger.isEnabledFor(logging.DEBUG):
            return
        if data is None:
            self.logger.debug(f'{msg}: None')
            return
        if isinstance(data, Tensor):
            self._shape_debug(msg, data)
        else:
            dtype = 'none'
            if len(data) > 0:
                dtype = type(data[0])
                if dtype is list:
                    dtype = f'{dtype} ({len(data)})'
            self.logger.debug(
                f'{msg}: length={len(data)}, type={dtype}')
        if full:
            from zensols.deeplearn import printopts
            with printopts(profile='full'):
                self.logger.debug('full data:\n' + str(data))

    def _forward(self, batch: Batch, context: SequenceNetworkContext) -> \
            SequenceNetworkOutput:
        """Run the network on ``batch`` based on the data set split given in
        ``context``.

        With the CRF enabled, training produces only the loss, validation
        decodes and computes the loss, and testing decodes and uses a zero
        valued loss.  With the CRF disabled, every split uses
        :meth:`_forward_train_no_crf`.

        :param batch: the batch to process

        :param context: provides the split type (and the criterion when the
                        CRF is disabled)

        :return: the predictions, loss, score and labels

        :raises ModelError: if a non-training split is used while this module
                            is in training mode or the split type is unknown

        """
        use_crf = self.net_settings.use_crf
        split_type: DatasetSplitType = context.split_type
        preds: List[List[int]] = None
        labels: Optional[Tensor] = batch.get_labels()
        loss: Tensor = None
        score: Tensor = None

        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f'forward on splt: {context.split_type}')

        if split_type != DatasetSplitType.train and self.training:
            raise ModelError(
                f'Attempting to use split {split_type} while training')

        if split_type == DatasetSplitType.train:
            if use_crf:
                loss = self._forward_train_with_crf(batch)
            else:
                preds, loss = self._forward_train_no_crf(batch, context)
        elif split_type == DatasetSplitType.validation:
            if use_crf:
                preds, loss, score = self._decode(batch, True)
            else:
                preds, loss = self._forward_train_no_crf(batch, context)
        elif split_type == DatasetSplitType.test:
            if use_crf:
                preds, _, score = self._decode(batch, False)
                # no loss is computed on the test split; use a zero value
                loss = batch.torch_config.singleton([0], dtype=torch.float32)
            else:
                preds, loss = self._forward_train_no_crf(batch, context)
        else:
            raise ModelError(f'Unknown data split type: {split_type}')

        # list of lists of the predictions, which are the CRF output when
        # enabled
        if preds is not None:
            preds = self._map_labels(batch, context, preds)
        # padded tensor of shape (batch, data i.e. token length)
        if labels is not None:
            labels = self._map_labels(batch, context, labels)

        self._shape_or_list_debug('output preds', preds)
        self._shape_or_list_debug('output labels', labels)

        out = SequenceNetworkOutput(preds, loss, score, labels)
        if preds is not None and labels is not None and \
           len(labels.size()) > 1:
            out.righsize_labels(preds)
        return out