"""A :class:`zensols.nlp.container.FeatureDocument` decorator that populatessentence and token embeddings."""__author__='Paul Landes'fromtypingimportOptional,Union,Listfromdataclassesimportdataclass,fieldimportnumpyasnpimporttorchfromtorchimportTensorfromzensols.deeplearnimportTorchConfigfromzensols.nlpimport(FeatureToken,FeatureSentence,FeatureDocument,FeatureDocumentDecorator)from.importWordEmbedModel
@dataclass
class WordEmbedDocumentDecorator(FeatureDocumentDecorator):
    """Populates sentence and token embeddings in the document.  Token
    embeddings have shape ``(1, d)``, where ``d`` is the embedding dimension
    and the first dimension is always 1 to be compatible with the word piece
    embeddings populated by :class:`..transformer.WordPieceDocumentDecorator`.

    :see: :class:`.WordEmbedModel`

    """
    model: WordEmbedModel = field()
    """The word embedding model for populating tokens and sentences."""

    torch_config: Optional[TorchConfig] = field(default=None)
    """The Torch configuration used to allocate the embeddings on either the
    GPU or the CPU.  If ``None``, then Numpy :class:`numpy.ndarray` arrays are
    used instead of :class:`torch.Tensor` tensors.

    """
    token_embeddings: bool = field(default=True)
    """Whether to add :class:`.WordPieceFeatureToken.embeddings`."""

    sent_embeddings: bool = field(default=True)
    """Whether to add :class:`.WordPieceFeatureSentence.embeddings`."""

    skip_oov: bool = field(default=False)
    """Whether to skip out-of-vocabulary tokens that have no embeddings."""

    def _add_sent_embedding(self, sent: FeatureSentence):
        use_np: bool = self.torch_config is None
        add_tok_emb: bool = self.token_embeddings
        model: WordEmbedModel = self.model
        # the embedding matrix is a numpy array when no torch config is given
        emb: Union[np.ndarray, Tensor]
        # per token embeddings later averaged into the sentence embedding
        sembs: List[Union[np.ndarray, Tensor]] = []
        if use_np:
            # already a numpy array
            emb = model.matrix
        else:
            # convert to a torch tensor on the configured device
            emb = model.to_matrix(self.torch_config)
        tok: FeatureToken
        for tok in sent.token_iter():
            norm: str = tok.norm
            idx: Optional[int] = model.word2idx(norm)
            if not self.skip_oov or idx is not None:
                if idx is None:
                    # fall back to the unknown token index for OOV tokens
                    idx = model.unk_idx
                vec: Union[np.ndarray, Tensor] = emb[idx]
                sembs.append(vec)
                if add_tok_emb:
                    # add a leading dimension of 1 for word piece compatibility
                    if use_np:
                        vec = np.expand_dims(vec, axis=0)
                    else:
                        vec = vec.unsqueeze(0)
                    tok.embedding = vec
        # sentence embeddings are the centroid of non-contextual embeddings
        if len(sembs) > 0 and self.sent_embeddings:
            if use_np:
                sent.embedding = np.stack(sembs).mean(axis=0)
            else:
                sent.embedding = torch.stack(sembs).mean(axis=0)
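
    # The ``decorate`` method below is a minimal sketch and is not part of
    # the original module: it assumes the :class:`.FeatureDocumentDecorator`
    # contract is a ``decorate(doc)`` call on each parsed document, and it
    # simply applies the per sentence logic above.
    def decorate(self, doc: FeatureDocument):
        """Populate embeddings on each sentence of ``doc`` (sketch)."""
        if self.token_embeddings or self.sent_embeddings:
            sent: FeatureSentence
            for sent in doc.sents:
                self._add_sent_embedding(sent)


# Example use (a sketch under assumptions): ``glove_model`` is a hypothetical
# already loaded :class:`.WordEmbedModel` and ``doc`` a parsed document:
#
#   decorator = WordEmbedDocumentDecorator(
#       model=glove_model, torch_config=TorchConfig())
#   decorator.decorate(doc)
#   doc.sents[0].embedding.shape            # (d,) sentence centroid
#   doc.sents[0].tokens[0].embedding.shape  # (1, d) per token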