Source code for zensols.deepnlp.transformer.pred

"""Predictions output for transformer models.

"""
__author__ = 'Paul Landes'

from typing import Callable, List, Tuple, Dict, Iterable, Any
from dataclasses import dataclass, field
import logging
import pandas as pd
from torch import Tensor
from zensols.nlp import FeatureToken, FeatureSentence, FeatureDocument
from zensols.deeplearn.batch import Batch, DataPoint
from zensols.deeplearn.result import SequencePredictionsDataFrameFactory
from . import TokenizedDocument

logger = logging.getLogger(__name__)

@dataclass
class TransformerSequencePredictionsDataFrameFactory(
        SequencePredictionsDataFrameFactory):
    """Like the super class but create predictions for transformer sequence
    models.  By default, transformer input is truncated at the model's max
    token length (usually 512 word piece tokens).  This class then truncates
    the tokens that are added as the ``text`` column from (configured by
    default) :class:`..classify.TokenClassifyModelFacade`.

    For all predictions where the sequence passed the model's maximum, this
    class maps that last word piece token output to the respective token in
    the :obj:`predictions_dataframe_factory_class` instance's ``transform``
    output.

    """
    embedded_document_attribute: str = field(default=None)
    """The :obj:`~zensols.deeplearn.batch.domain.Batch` attribute key for the
    tensor that contains the vectorized document.

    """
    def _trunc_tokens(self, batch: Batch) -> Iterable[int]:
        """Yield, for each data point, the number of entries in its word
        piece token mapping.

        :param batch: contains the ``data_points`` with the documents to map
                      and the embedded document tensor stored under
                      :obj:`embedded_document_attribute`

        :return: one length per data point in ``batch``, in data point order

        """
        # merge documents from the data points into a document for the batch
        dps_doc: FeatureDocument = FeatureDocument.combine_documents(
            dp.doc for dp in batch.data_points)
        # re-hydrate the vectorized document from the batch tensor
        emb: Tensor = batch[self.embedded_document_attribute]
        tdoc: TokenizedDocument = TokenizedDocument.from_tensor(emb)
        # map word piece tokens to feature document tokens
        sent_maps: List[Dict[str, Any]] = tdoc.map_to_word_pieces(
            sentences=dps_doc,
            includes={'map', 'sent'})
        # index (rather than zip) so a mapping shorter than the data points
        # fails loudly instead of silently dropping data points
        for dpix, _ in enumerate(batch.data_points):
            tmap: Tuple[FeatureToken, Tuple[Tuple[str, int, int], ...]] = \
                sent_maps[dpix]['map']
            yield len(tmap)

    def _calc_len(self, batch: Batch) -> int:
        """Compute the per data point truncated lengths, cache them on the
        batch and return their total.

        The lengths are stored as ``batch._trunc_lens`` because
        :meth:`_transform_dataframe` reads them back from the batch.

        :param batch: the batch for which to compute lengths

        :return: the sum of the truncated lengths of all data points

        """
        trunc_lens: Tuple[int, ...] = tuple(self._trunc_tokens(batch))
        batch._trunc_lens = trunc_lens
        return sum(trunc_lens)

    def _transform_dataframe(self, batch: Batch, labs: List[str],
                             preds: List[str]) -> pd.DataFrame:
        """Create a data frame of labels and predictions with the columns
        produced by :obj:`data_point_transform` for every data point.

        ``labs`` and ``preds`` are flat sequences that are consumed in
        consecutive slices, one slice of ``batch._trunc_lens[i]`` elements
        for the ``i``-th data point.

        NOTE(review): reads ``batch._trunc_lens``, which in this class is
        only set by :meth:`_calc_len`; presumably the super class calls that
        on the same batch first -- confirm.

        :param batch: the batch that was used to create the predictions

        :param labs: the gold labels for all data points of the batch

        :param preds: the predictions for all data points of the batch

        :return: the concatenation of the per data point data frames

        """
        dfs: List[pd.DataFrame] = []
        start: int = 0
        transform: Callable = self.data_point_transform
        self._assert_label_pred_batch_size(batch, labs, preds, False)
        dp: DataPoint
        tl: int
        for dp, tl in zip(batch.data_points, batch._trunc_lens):
            end: int = start + tl
            df = pd.DataFrame({
                # NOTE(review): the value was missing in the reviewed source
                # (``self.ID_COL:,``); ``dp.id`` is assumed to match the
                # parent class -- confirm
                self.ID_COL: dp.id,
                self.LABEL_COL: labs[start:end],
                self.PREDICTION_COL: preds[start:end]})
            dp_data: Tuple[Tuple[str, ...]] = transform(dp)
            if len(dp_data) != tl and logger.isEnabledFor(logging.WARNING):
                sent_str: str = ''
                if hasattr(dp, 'doc'):
                    doc: FeatureDocument = dp.doc
                    sent_str = f' for document: {doc}'
                logger.warning(
                    f'trimming outcomes from {len(dp_data)} ' +
                    f'to word piece max (equivalent) {tl}{sent_str}')
            # truncate unconditionally so the result never depends on the
            # configured log level; this is a no-op when nothing was cut off
            dp_data = dp_data[:tl]
            df[list(self.column_names)] = dp_data
            dfs.append(df)
            start = end
        return pd.concat(dfs)