Source code for zensols.deepnlp.layer.embed

"""An embedding layer module useful for models that use embeddings as input.

"""
__author__ = 'Paul Landes'

from typing import Dict, List, Tuple, Callable
from dataclasses import dataclass, field
import logging
import torch
from torch import nn
from torch import Tensor
from zensols.persist import Deallocatable
from zensols.deeplearn import ModelError
from zensols.deeplearn.vectorize import FeatureVectorizer
from zensols.deeplearn.model import BaseNetworkModule, DebugModule
from zensols.deeplearn.layer import LayerError
from zensols.deeplearn.batch import (
    Batch,
    BatchMetadata,
    BatchFieldMetadata,
    MetadataNetworkSettings,
)
from zensols.deepnlp.embed import WordEmbedModel
from zensols.deepnlp.vectorize import (
    TextFeatureType,
    FeatureDocumentVectorizer,
    FeatureDocumentVectorizerManager,
    EmbeddingFeatureVectorizer,
)

logger = logging.getLogger(__name__)


class EmbeddingLayer(DebugModule, Deallocatable):
    """A class used as an input layer to provide word embeddings to a deep
    neural network.

    **Important**: you must always check for attributes in
    :meth:`~zensols.persist.dealloc.Deallocatable.deallocate` since it might be
    called more than once (i.e. from directly deallocating and then from the
    factory).

    **Implementation note**: No dataclasses are usable since PyTorch is picky
    about initialization order.

    """

    def __init__(self, feature_vectorizer_manager:
                 FeatureDocumentVectorizerManager,
                 embedding_dim: int, sub_logger: logging.Logger = None,
                 trainable: bool = False):
        """Initialize.

        :param feature_vectorizer_manager: the feature vectorizer manager that
                                           manages this instance

        :param embedding_dim: the vector dimension of the embedding

        :param trainable: ``True`` if the embedding layer is to be trained

        """
        super().__init__(sub_logger)
        self.feature_vectorizer_manager = feature_vectorizer_manager
        self.embedding_dim = embedding_dim
        self.trainable = trainable

    @property
    def token_length(self):
        return self.feature_vectorizer_manager.token_length

    @property
    def torch_config(self):
        return self.feature_vectorizer_manager.torch_config

    def deallocate(self):
        super().deallocate()
        if hasattr(self, 'emb'):
            if logger.isEnabledFor(logging.DEBUG):
                em = '<deallocated>'
                if hasattr(self, 'embed_model'):
                    em = self.embed_model.name
                self._debug(f'deallocating: {em} and {type(self.emb)}')
            self._try_deallocate(self.emb)
            del self.emb
            if hasattr(self, 'embed_model'):
                del self.embed_model
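

# --- Illustrative sketch (not part of the library) --------------------------
# A minimal concrete subclass showing the attributes that ``deallocate`` above
# guards on: concrete layers conventionally set ``self.embed_model`` (the
# ``WordEmbedModel``) and ``self.emb`` (the torch module holding the vectors).
# The ``nn.Embedding`` wiring and the ``_Demo*`` name are assumptions for
# demonstration only, not part of the library API.
class _DemoWordEmbeddingLayer(EmbeddingLayer):
    def __init__(self, feature_vectorizer_manager:
                 FeatureDocumentVectorizerManager,
                 embed_model: WordEmbedModel, **kwargs):
        super().__init__(feature_vectorizer_manager,
                         embed_model.matrix.shape[1], **kwargs)
        self.embed_model = embed_model
        # pretrained vectors; frozen unless ``trainable`` was passed
        self.vecs = torch.as_tensor(embed_model.matrix, dtype=torch.float32)
        self.emb = nn.Embedding.from_pretrained(
            self.vecs, freeze=not self.trainable)

    def forward(self, x: Tensor) -> Tensor:
        # map token indexes (batch, tokens) to vectors (batch, tokens, dim)
        return self.emb(x)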


class TrainableEmbeddingLayer(EmbeddingLayer):
    """A non-frozen embedding layer that has grad on parameters.

    """

    def reset_parameters(self):
        if self.trainable:
            self.emb.load_state_dict({'weight': self.vecs})

    def _get_emb_key(self, prefix: str):
        return f'{prefix}emb.weight'

    def state_dict(self, *args,
                   destination=None, prefix='', keep_vars=False):
        state = super().state_dict(
            *args,
            destination=destination,
            prefix=prefix,
            keep_vars=keep_vars)
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'state_dict: trainable: {self.trainable}')
        if not self.trainable:
            emb_key = self._get_emb_key(prefix)
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'state_dict: embedding key: {emb_key}')
            if emb_key is not None:
                if emb_key not in state:
                    raise ModelError(f'No key {emb_key} in {state.keys()}')
                arr = state[emb_key]
                if arr is not None:
                    if logger.isEnabledFor(logging.DEBUG):
                        self._debug(f'state_dict: emb state: {arr.shape}')
                    assert arr.shape == self.embed_model.matrix.shape
                state[emb_key] = None
        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.trainable:
            emb_key = self._get_emb_key(prefix)
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'load_state_dict: {emb_key}')
            if emb_key is not None:
                state_dict[emb_key] = self.vecs
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
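

# --- Illustrative sketch (not part of the library) --------------------------
# The ``state_dict``/``_load_from_state_dict`` overrides above implement a
# space saving pattern: when the layer is frozen, the (potentially huge)
# embedding weight is nulled out of checkpoints and re-injected from
# ``self.vecs`` at load time.  The hypothetical helper below shows the same
# round trip with a plain ``nn.Embedding``.
def _demo_frozen_checkpoint_round_trip() -> None:
    vecs: Tensor = torch.zeros(10, 5)
    emb = nn.Embedding.from_pretrained(vecs, freeze=True)
    # save side: drop the frozen weight so the serialized checkpoint is small
    state = emb.state_dict()
    state['weight'] = None
    # load side: restore the pretrained vectors before ``load_state_dict``
    state['weight'] = vecs
    emb.load_state_dict(state)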


@dataclass
class EmbeddingNetworkSettings(MetadataNetworkSettings):
    """A utility container settings class for models that use an embedding
    input layer and inherit from :class:`.EmbeddingNetworkModule`.

    """
    embedding_layer: EmbeddingLayer = field()
    """The word embedding layer used to vectorize."""

    def get_module_class_name(self) -> str:
        return __name__ + '.EmbeddingNetworkModule'
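

# --- Illustrative sketch (not part of the library) --------------------------
# ``get_module_class_name`` is how a settings instance is mapped to the
# ``nn.Module`` class it configures.  A hypothetical model would subclass the
# settings and point the name at its own module class (here the demo
# classifier sketched at the end of this listing):
@dataclass
class _DemoClassifierNetworkSettings(EmbeddingNetworkSettings):
    def get_module_class_name(self) -> str:
        return __name__ + '._DemoClassifierNetworkModule'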


@dataclass
class _EmbeddingContainer(object):
    """Contains the matching of vectorizer, embedding layer and field mapping.

    """
    field_meta: BatchFieldMetadata = field()
    """The mapping that has the batch attribute name for the embedding."""

    vectorizer: FeatureDocumentVectorizer = field()
    """The vectorizer used to encode the batch data."""

    embedding_layer: EmbeddingLayer = field()
    """The word embedding layer used to vectorize."""

    @property
    def dim(self) -> int:
        """The embedding's dimension."""
        return self.embedding_layer.embedding_dim

    @property
    def attr(self) -> str:
        """The attribute name of the layer's mapping."""
        return self.field_meta.field.attr

    def get_embedding_tensor(self, batch: Batch) -> Tensor:
        """Get the embedding (or indexes depending on how it was vectorized)."""
        return batch[self.attr]

    def __str__(self) -> str:
        return self.attr

    def __repr__(self) -> str:
        return self.__str__()


class EmbeddingNetworkModule(BaseNetworkModule):
    """A module that uses an embedding as the input layer.  This class uses an
    instance of :class:`.EmbeddingLayer` provided by the network settings
    configuration for resolving the embedding during the *forward* phase.

    The following attributes are created and/or set during initialization:

      * ``embedding`` the :class:`.EmbeddingLayer` instance used to get the
        input embedding tensors

      * ``embedding_attribute_names`` the names of the word embedding
        vectorized feature attributes (usually one, but it is possible to have
        more)

      * ``embedding_output_size`` the output size of the embedding layer; note
        that this includes any features layered/concatenated given in all
        token level vectorizers' configuration

      * ``join_size`` if a join layer is to be used, this has the size of the
        part of the join layer that will have the document level features

      * ``token_attribs`` the token level feature names (see
        :meth:`forward_token_features`)

      * ``doc_attribs`` the doc level feature names (see
        :meth:`forward_document_features`)

    The initializer adds additional attributes conditional on the
    :class:`.EmbeddingNetworkSettings` instance's
    :obj:`~zensols.deeplearn.batch.meta.MetadataNetworkSettings.batch_metadata`
    property (type :class:`~zensols.deeplearn.batch.meta.BatchMetadata`).  For
    each metadata field whose vectorizer extends
    :class:`.FeatureDocumentVectorizer`, the following is set on this instance
    based on the value of ``feature_type`` (of type :class:`.TextFeatureType`):

      * :obj:`~.TextFeatureType.TOKEN`: ``embedding_output_size`` is increased
        by the vectorizer's shape

      * :obj:`~.TextFeatureType.DOCUMENT`: ``join_size`` is increased by the
        vectorizer's shape

    Fields can be filtered by passing a filter function to the initializer.
    See :meth:`__init__` for more information.

    """
    MODULE_NAME = 'embed'

    def __init__(self, net_settings: EmbeddingNetworkSettings,
                 module_logger: logging.Logger = None,
                 filter_attrib_fn: Callable[[BatchFieldMetadata],
                                            bool] = None):
        """Initialize the embedding layer.

        :param net_settings: the embedding layer configuration

        :param module_logger: the logger to use for the forward process in
                              this layer

        :param filter_attrib_fn: if provided, called with a
                                 :class:`.BatchFieldMetadata` for each field
                                 returning ``True`` if the batch field should
                                 be retained and used in the embedding layer
                                 (see class docs); if ``None`` all fields are
                                 considered

        """
        super().__init__(net_settings, module_logger)
        self.embedding_output_size: int = 0
        self.join_size: int = 0
        self.token_attribs: List[str] = []
        self.doc_attribs: List[str] = []
        self._embedding_containers: List[_EmbeddingContainer] = []
        self._embedding_layers = self._map_embedding_layers()
        self._embedding_sequence = nn.Sequential(
            *tuple(self._embedding_layers.values()))
        field: BatchFieldMetadata
        meta: BatchMetadata = self.net_settings.batch_metadata
        fba: Dict[str, BatchFieldMetadata] = meta.fields_by_attribute
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'batch field metadata: {fba}')
        for name in sorted(fba.keys()):
            field_meta: BatchFieldMetadata = fba[name]
            if filter_attrib_fn is not None and \
               not filter_attrib_fn(field_meta):
                if logger.isEnabledFor(logging.DEBUG):
                    self._debug(f'skipping: {name}')
                continue
            vec: FeatureVectorizer = field_meta.vectorizer
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'{name} -> {field_meta}')
            if isinstance(vec, FeatureDocumentVectorizer):
                try:
                    self._add_field(vec, field_meta)
                except Exception as e:
                    raise ModelError(
                        f'Could not create field {field_meta}: {e}') from e
        if len(self._embedding_containers) == 0:
            raise LayerError('No embedding vectorizer feature type found')

    def _map_embedding_layers(self) -> Dict[int, EmbeddingLayer]:
        """Return a mapping of embedding layers configured using their
        in-memory location as keys.

        """
        els: Tuple[EmbeddingLayer]
        els = self.net_settings.embedding_layer
        if not isinstance(els, (tuple, list)):
            els = [els]
        return {id(el.embed_model): el for el in els}

    def _add_field(self, vec: FeatureDocumentVectorizer,
                   field_meta: BatchFieldMetadata):
        """Add a batch metadata field and its respective vectorizer to class
        member data structures.

        """
        attr = field_meta.field.attr
        if vec.feature_type == TextFeatureType.TOKEN:
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'adding tok type {attr}: {vec.shape[2]}')
            self.embedding_output_size += vec.shape[2]
            self.token_attribs.append(attr)
        elif vec.feature_type == TextFeatureType.DOCUMENT:
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'adding doc type {attr} ' +
                            f'({field_meta.shape}/{vec.shape})')
            self.join_size += field_meta.shape[1]
            self.doc_attribs.append(attr)
        elif vec.feature_type == TextFeatureType.EMBEDDING:
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'adding embedding: {attr}')
            embedding_layer: EmbeddingLayer = \
                self._embedding_layers.get(id(vec.embed_model))
            if embedding_layer is None:
                raise ModelError(f'No embedding layer found for {attr}')
            if self.logger.isEnabledFor(logging.INFO):
                we_model: WordEmbedModel = embedding_layer.embed_model
                self.logger.info(f'embeddings: {we_model.name}')
            ec = _EmbeddingContainer(field_meta, vec, embedding_layer)
            self._embedding_containers.append(ec)
            self.embedding_output_size += ec.dim

    def get_embedding_tensors(self, batch: Batch) -> Tuple[Tensor]:
        """Get the embedding tensors (or indexes depending on how they were
        vectorized) from a batch.

        :param batch: contains the vectorized embeddings

        :return: the vectorized embedding as tensors, one for each embedding

        """
        return tuple(
            map(lambda ec: batch[ec.attr], self._embedding_containers))

    @property
    def embedding_dimension(self) -> int:
        """Return the dimension of the embeddings, which doesn't include any
        additional token or document features potentially added.

        """
        return sum(map(lambda ec: ec.dim, self._embedding_containers))

    def vectorizer_by_name(self, name: str) -> FeatureVectorizer:
        """Utility method to get a vectorizer by name.

        :param name: the name of the vectorizer as given in the vectorizer
                     manager

        """
        meta: BatchMetadata = self.net_settings.batch_metadata
        field_meta: BatchFieldMetadata = meta.fields_by_attribute[name]
        vec: FeatureVectorizer = field_meta.vectorizer
        return vec

    def _forward(self, batch: Batch) -> Tensor:
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'batch: {batch}')
        x = self.forward_embedding_features(batch)
        x = self.forward_token_features(batch, x)
        x = self.forward_document_features(batch, x)
        return x

    def _forward_embedding_layer(self, ec: _EmbeddingContainer,
                                 batch: Batch) -> Tensor:
        decoded: bool = False
        is_tok_vec: bool = isinstance(
            ec.vectorizer, EmbeddingFeatureVectorizer)
        x: Tensor = ec.get_embedding_tensor(batch)
        self._shape_debug('input', x)
        if self.logger.isEnabledFor(logging.DEBUG):
            self._debug(f'vectorizer type: {type(ec.vectorizer)}')
        if is_tok_vec:
            decoded = ec.vectorizer.decode_embedding
            if self.logger.isEnabledFor(logging.DEBUG):
                self._debug(f'is embedding already decoded: {decoded}')
        if not decoded:
            x = ec.embedding_layer(x)
        return x

    def forward_embedding_features(self, batch: Batch) -> Tensor:
        """Use the embedding layer to return the word embedding tensors.

        """
        arr: Tensor
        arrs: List[Tensor] = []
        ec: _EmbeddingContainer
        for ec in self._embedding_containers:
            x: Tensor = self._forward_embedding_layer(ec, batch)
            self._shape_debug(f'decoded sub embedding ({ec}):', x)
            arrs.append(x)
        if len(arrs) == 1:
            arr = arrs[0]
        else:
            arr = torch.concat(arrs, dim=-1)
            self._shape_debug(f'decoded concat embedding ({ec}):', arr)
        return arr

    def forward_token_features(self, batch: Batch,
                               x: Tensor = None) -> Tensor:
        """Concatenate any token features given by the vectorizer
        configuration.

        :param batch: contains token level attributes to concatenate to ``x``

        :param x: if given, the first tensor to be concatenated

        """
        self._shape_debug('forward token features', x)
        arrs = []
        if x is not None:
            self._shape_debug('adding passed token features', x)
            arrs.append(x)
        for attrib in self.token_attribs:
            feats = batch.attributes[attrib]
            self._shape_debug(f"token attrib '{attrib}'", feats)
            arrs.append(feats)
        if len(arrs) == 1:
            x = arrs[0]
        elif len(arrs) > 1:
            self._debug(f'concatenating {len(arrs)} token features')
            x = torch.cat(arrs, 2)
            self._shape_debug('token concat', x)
        return x

    def forward_document_features(self, batch: Batch, x: Tensor = None,
                                  include_fn: Callable = None) -> Tensor:
        """Concatenate any document features given by the vectorizer
        configuration.

        """
        self._shape_debug('forward document features', x)
        arrs = []
        if x is not None:
            arrs.append(x)
        for attrib in self.doc_attribs:
            if include_fn is not None and not include_fn(attrib):
                continue
            st = batch.attributes[attrib]
            self._shape_debug(f'doc attrib {attrib}', st)
            arrs.append(st)
        if len(arrs) > 0:
            x = torch.cat(arrs, 1)
            self._shape_debug('doc concat', x)
        return x
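

# --- Illustrative sketch (not part of the library) --------------------------
# A hypothetical subclass showing how the ``forward_*`` methods above are
# typically composed: embed, concatenate token level features, pool over
# tokens, join document level features, then decode.  The mean pooling, the
# linear head and the ``_Demo*`` name are assumptions for demonstration only.
class _DemoClassifierNetworkModule(EmbeddingNetworkModule):
    def __init__(self, net_settings: EmbeddingNetworkSettings,
                 n_labels: int = 2):
        super().__init__(net_settings)
        # token level width (embedding + token features) is pooled, then the
        # document level features are joined before the decoder
        self.decoder = nn.Linear(
            self.embedding_output_size + self.join_size, n_labels)

    def _forward(self, batch: Batch) -> Tensor:
        x = self.forward_embedding_features(batch)
        x = self.forward_token_features(batch, x)
        # (batch, tokens, features) -> (batch, features)
        x = x.mean(dim=1)
        x = self.forward_document_features(batch, x)
        return self.decoder(x)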