Source code for zensols.deepnlp.transformer.tokenizer

"""The tokenizer object.

"""
__author__ = 'Paul Landes'

from typing import List, Dict, Set, ClassVar, Any
from dataclasses import dataclass, field
import logging
from frozendict import frozendict
import torch
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from zensols.nlp import FeatureDocument
from zensols.persist import persisted, PersistableContainer
from zensols.deeplearn import TorchConfig
from zensols.deepnlp.transformer import TransformerResource
from . import TransformerError, TokenizedFeatureDocument

logger = logging.getLogger(__name__)


@dataclass
class TransformerDocumentTokenizer(PersistableContainer):
    """Creates instances of :class:`.TokenizedFeatureDocument` using a
    HuggingFace :class:`~transformers.PreTrainedTokenizer`.

    """
    DEFAULT_PARAMS: ClassVar[Dict[str, Any]] = frozendict({
        'return_offsets_mapping': True,
        'is_split_into_words': True,
        'return_special_tokens_mask': True,
        'padding': 'longest'})
    """Default parameters for the HuggingFace tokenizer.  These are overridden
    (in this order) by :obj:`params`, the truncation settings derived from
    :obj:`word_piece_token_length` and finally the ``tokenizer_kwargs`` given
    to :meth:`tokenize`.

    """
    resource: TransformerResource = field()
    """Contains the model used to create the tokenizer."""

    word_piece_token_length: int = field(default=None)
    """The max number of word piece tokens.  The word piece length is always
    the same or greater in count than linguistic tokens because the word piece
    algorithm tokenizes on characters.

    How the value is used:

      * ``None`` (default): replaced at construction time with the
        tokenizer's ``model_max_length``

      * greater than 0: sentences are truncated to this many word piece
        tokens (``truncation=True``, ``max_length=<value>``)

      * 0 or less: truncation is turned off (``truncation=False``) so
        documents are multi-length

    Turning truncation off has the potential of creating token spans longer
    than the model can tolerate (usually 512 word piece tokens).  In these
    cases, this value must be set to (or lower than) the model's
    ``model_max_length``.

    Tokenization padding is on by default.

    :see: `HF Docs <https://huggingface.co/docs/transformers/pad_truncation>`_

    """
    params: Dict[str, Any] = field(default=None)
    """Additional parameters given to the
    :class:`transformers.PreTrainedTokenizer`.

    """
    def __post_init__(self):
        super().__init__()
        # resolve the "use the model's limit" default once, up front
        if self.word_piece_token_length is None:
            self.word_piece_token_length = \
                self.resource.tokenizer.model_max_length

    @property
    def pretrained_tokenizer(self) -> PreTrainedTokenizer:
        """The HuggingFace tokenizer used to create tokenized documents."""
        return self.resource.tokenizer

    @property
    def token_max_length(self) -> int:
        """The word piece token maximum length: the tokenizer's
        ``model_max_length`` when :obj:`word_piece_token_length` is ``None``
        or 0, otherwise :obj:`word_piece_token_length` as configured.

        """
        if self.word_piece_token_length is None or \
           self.word_piece_token_length == 0:
            return self.pretrained_tokenizer.model_max_length
        return self.word_piece_token_length

    @property
    @persisted('_id2tok')
    def id2tok(self) -> Dict[int, str]:
        """A mapping from the HuggingFace tokenizer's vocabulary IDs to their
        word piece equivalents.

        """
        vocab = self.pretrained_tokenizer.vocab
        return {tok_id: tok for tok, tok_id in vocab.items()}

    @property
    @persisted('_all_special_tokens')
    def all_special_tokens(self) -> Set[str]:
        """Special tokens used by the model (such as BERT's ``[CLS]`` and
        ``[SEP]`` tokens).

        """
        return frozenset(self.pretrained_tokenizer.all_special_tokens)

    def tokenize(self, doc: FeatureDocument,
                 tokenizer_kwargs: Dict[str, Any] = None) -> \
            TokenizedFeatureDocument:
        """Tokenize a feature document in a form that's easy to inspect and
        provide to :class:`.TransformerEmbedding` to transform.

        :param doc: the document to tokenize

        :param tokenizer_kwargs: parameters passed to the HuggingFace
                                 tokenizer that override all other settings

        :raises TransformerError: if the tokenizer is not a *fast* tokenizer,
                                  or the document or any of its sentences is
                                  empty

        """
        if not self.resource.tokenizer.is_fast:
            raise TransformerError(
                'only fast tokenizers are supported for needed offset mapping')
        # one list of token text per sentence
        sents: List[List[str]] = [[tok.text for tok in sent] for sent in doc]
        return self._from_tokens(sents, doc, tokenizer_kwargs)

    def _from_tokens(self, sents: List[List[str]], doc: FeatureDocument,
                     tokenizer_kwargs: Dict[str, Any] = None) -> \
            TokenizedFeatureDocument:
        """Run the HuggingFace tokenizer on pre-split sentences and package
        the output as a tokenized document.

        :param sents: the text of each token, grouped by sentence

        :param doc: the document ``sents`` was created from

        :param tokenizer_kwargs: parameters passed to the HuggingFace
                                 tokenizer that override all other settings

        :raises TransformerError: if ``sents`` or any sentence in it is empty

        """
        torch_config: TorchConfig = self.resource.torch_config
        tlen: int = self.word_piece_token_length
        tokenizer: PreTrainedTokenizer = self.pretrained_tokenizer
        params: Dict[str, Any] = dict(self.DEFAULT_PARAMS)
        if self.params is not None:
            params.update(self.params)

        # an empty document would otherwise fail later with an opaque
        # indexing error when the special token mask is read
        if len(sents) == 0:
            raise TransformerError(
                'Document has no sentences: can not tokenize')
        for i, sent in enumerate(sents):
            if len(sent) == 0:
                raise TransformerError(
                    f'Sentence {i} is empty: can not tokenize')
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'parsing {sents} with token length: {tlen}')

        # only a positive length fixes the sentence length
        if tlen > 0:
            params.update({'truncation': True, 'max_length': tlen})
        else:
            params.update({'truncation': False})
        # caller supplied settings have the last word
        if tokenizer_kwargs is not None:
            params.update(tokenizer_kwargs)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'using tokenizer parameters: {params}')

        tok_dat: BatchEncoding = tokenizer(sents, **params)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"lengths: {[len(i) for i in tok_dat['input_ids']]}")
            logger.debug(f"inputs: {tok_dat['input_ids']}")

        input_ids = tok_dat.input_ids
        char_offsets = tok_dat.offset_mapping
        # whether the first word piece of the first sentence is a special
        # token
        boundary_tokens = (tok_dat.special_tokens_mask[0][0]) == 1
        # per sentence, the index of the source token of each word piece;
        # word pieces without a source token (``None``) are mapped to -1
        sent_offsets = tuple(
            tuple(-1 if wid is None else wid
                  for wid in tok_dat.word_ids(batch_index=six))
            for six in range(len(input_ids)))

        if logger.isEnabledFor(logging.DEBUG):
            for six, tids in enumerate(sent_offsets):
                logger.debug(f'tok ids: {tids}')
                for stix, tix in enumerate(tids):
                    bid = tok_dat['input_ids'][six][stix]
                    wtok = self.id2tok[bid]
                    if tix >= 0:
                        stok = sents[six][tix]
                    else:
                        stok = '-'
                    logger.debug(
                        f'sent={six}, idx={tix}, id={bid}: {wtok} -> {stok}')

        tok_data = [input_ids, tok_dat.attention_mask, sent_offsets]
        # not every tokenizer produces token type IDs
        if hasattr(tok_dat, 'token_type_ids'):
            tok_data.append(tok_dat.token_type_ids)
        arr = torch_config.singleton(tok_data, dtype=torch.long)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f'tok doc mat: shape={arr.shape}, dtype={arr.dtype}')

        return TokenizedFeatureDocument(
            tensor=arr,
            # needed by expander vectorizer
            boundary_tokens=boundary_tokens,
            char_offsets=char_offsets,
            feature=doc,
            id2tok=self.id2tok)

    def __call__(self, doc: FeatureDocument,
                 tokenizer_kwargs: Dict[str, Any] = None) -> \
            TokenizedFeatureDocument:
        """Shorthand for :meth:`tokenize`."""
        return self.tokenize(doc, tokenizer_kwargs)