"""This file contains a stash used to load an embedding layer. It createsfeatures in batches of matrices and persists matrix only (sans features) forefficient retrival."""from__future__importannotations__author__='Paul Landes'fromtypingimportTuple,Iterable,List,Union,ClassVar,TYPE_CHECKINGifTYPE_CHECKING:from..transformer.embedimportTransformerEmbeddingfromdataclassesimportdataclass,fieldimportloggingfromitertoolsimportchainimporttorchfromtorchimportTensorfromzensols.configimportDictablefromzensols.persistimportpersisted,Primeablefromzensols.deeplearn.vectorizeimportFeatureContext,TensorFeatureContextfromzensols.nlpimportFeatureToken,FeatureDocument,FeatureSentencefromzensols.deepnlp.embedimportWordEmbedModelfromzensols.deepnlp.vectorizeimportTextFeatureTypefrom.importFoldingDocumentVectorizerlogger=logging.getLogger(__name__)
@dataclass
class EmbeddingFeatureVectorizer(FoldingDocumentVectorizer, Primeable,
                                 Dictable):
    """Vectorize a :class:`~zensols.nlp.container.FeatureDocument` as a vector
    of embedding indexes.  Later, these indexes are used in a
    :class:`~zensols.deepnlp.layer.embed.EmbeddingLayer` to create the input
    word embedding during execution of the model.

    """
    embed_model: Union[WordEmbedModel, TransformerEmbedding] = field()
    """The word vector model.  Types for this value include:

      * :class:`~zensols.deepnlp.embed.domain.WordEmbedModel`
      * :class:`~zensols.deepnlp.transformer.embed.TransformerEmbedding`

    """
    decode_embedding: bool = field(default=False)
    """Whether or not to decode the embedding during the decode phase, which
    is helpful when caching batches; otherwise, the data is decoded from
    indexes to embeddings on each epoch.

    Note that this option and functionality cannot be obviated by that
    implemented with the :obj:`encode_transformed` attribute.  The difference
    is whether more work is done during decoding rather than encoding.  An
    example of when this is useful is for large word embeddings (i.e. the
    Google 300D pretrained model), where the index to tensor embedding
    transform is done while decoding rather than in ``forward``, so it is not
    repeated on every epoch.

    """
    def _get_shape(self) -> Tuple[int, int]:
        # rows are the (padded) token count, columns the embedding dimension
        return self.manager.token_length, self.embed_model.vector_dimension
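# A minimal sketch of how a configured instance reports its shape.  The
# ``vec_manager`` lookup and the ``wvglove300`` section name below are
# hypothetical application configuration (not part of this module), and assume
# a manager token length of 20 with a 300 dimension GloVe model:
#
#     vec: EmbeddingFeatureVectorizer = vec_manager['wvglove300']
#     vec.shape  # -> (20, 300): (token length, embedding dimension)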
@dataclass
class WordVectorEmbeddingFeatureVectorizer(EmbeddingFeatureVectorizer):
    """Vectorize sentences using an embedding model (:obj:`embed_model`) of
    type :class:`.WordEmbedModel`.

    The encoder returns the indices of the word embedding for each token in
    the input :class:`.FeatureDocument`.  The decoder returns the
    corresponding word embedding vectors if :obj:`decode_embedding` is
    ``True``.  Otherwise, it returns the same indices, which are later used
    by the embedding layer (usually
    :class:`~zensols.deepnlp.layer.EmbeddingLayer`).

    """
    DESCRIPTION: ClassVar[str] = 'word vector document embedding'
    FEATURE_TYPE: ClassVar[TextFeatureType] = TextFeatureType.EMBEDDING

    token_feature_id: str = field(default='norm')
    """The :class:`~zensols.nlp.tok.FeatureToken` attribute used to index the
    embedding vectors.

    """
    def _encode(self, doc: FeatureDocument) -> FeatureContext:
        emodel: WordEmbedModel = self.embed_model
        tw: int = self.manager.get_token_length(doc)
        sents: Tuple[FeatureSentence, ...] = doc.sents
        shape: Tuple[int, int] = (len(sents), tw)
        tfid: str = self.token_feature_id
        arr: Tensor = self.torch_config.empty(shape, dtype=torch.long)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'using token length: {tw} with shape: {shape}, ' +
                         f'sents: {len(sents)}')
        row: int
        sent: FeatureSentence
        for row, sent in enumerate(sents):
            # truncate the sentence to the configured token length
            tokens: List[FeatureToken] = sent.tokens[0:tw]
            slen: int = len(tokens)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'row: {row}, toks: ' +
                             ' '.join(map(lambda x: x.norm, tokens)))
            tokens = list(map(lambda t: getattr(t, tfid), tokens))
            if slen < tw:
                # pad short sentences with the zero vector token
                tokens += [WordEmbedModel.ZERO] * (tw - slen)
            for i, tok in enumerate(tokens):
                # use the embedding index of the token, falling back to the
                # unknown token's index when out of vocabulary
                arr[row][i] = emodel.word2idx_or_unk(tok)
        return TensorFeatureContext(self.feature_id, arr)

    @property
    @persisted('_vectors')
    def vectors(self) -> Tensor:
        embed_model: WordEmbedModel = self.embed_model
        return embed_model.to_matrix(self.torch_config)

    def _decode(self, context: FeatureContext) -> Tensor:
        x: Tensor = super()._decode(context)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'indexes: {x.shape} ({x.dtype}), ' +
                         f'will decode in vectorizer: {self.decode_embedding}')
        if self.decode_embedding:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'decoding using: {self.decode_embedding}')
            # replace each embedding index with its vector from the embedding
            # matrix, stacking rows back into sentence and batch dimensions
            src_vecs: Tensor = self.vectors
            batches: List[Tensor] = []
            vecs: List[Tensor] = []
            batch_idx: Tensor
            for batch_idx in x:
                idxt: int
                for idxt in batch_idx:
                    vecs.append(src_vecs[idxt])
                batches.append(torch.stack(vecs))
                vecs.clear()
            x = torch.stack(batches)
        return x
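# A minimal end-to-end sketch of the encode/decode round trip.  The
# ``doc_parser`` and ``vec_manager`` names below are hypothetical application
# configuration (not part of this module):
#
#     doc: FeatureDocument = doc_parser.parse('I am a sentence.')
#     vec: WordVectorEmbeddingFeatureVectorizer = vec_manager['wvglove50']
#     ctx: FeatureContext = vec.encode(doc)
#     arr: Tensor = vec.decode(ctx)
#     # decode_embedding=False -> arr.shape == (n sents, token length)
#     # decode_embedding=True  -> arr.shape == (n sents, token length, dim)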