Source code for zensols.calamr.annotate
"""Contain a class to add embeddings to AMR feature documents.
"""
__author__ = 'Paul Landes'
from typing import Dict, Tuple, Any
from dataclasses import dataclass, field
import logging
from zensols.util import time
from zensols.persist import DelegateStash, PrimeableStash
from zensols.amr import AmrFailure, AmrSentence, AmrFeatureDocument
from zensols.amr.annotate import (
AnnotatedAmrFeatureDocumentFactory,
AnnotatedAmrDocument, AnnotatedAmrSectionDocument,
)
from zensols.deepnlp.transformer import (
WordPieceFeatureDocumentFactory, WordPieceFeatureDocument
)
logger = logging.getLogger(__name__)
@dataclass
class _EmbeddingPopulator(object):
    """Mixin that copies word piece embeddings on to AMR feature documents.

    Population is skipped entirely when :obj:`word_piece_doc_factory` is
    ``None``.

    """
    word_piece_doc_factory: WordPieceFeatureDocumentFactory = field()
    """The feature document factory that populates embeddings."""

    def _populate_embeddings(self, doc: AmrFeatureDocument, doc_id: str = None):
        """Create a word piece document from ``doc`` and copy its embedding
        back on to ``doc``.

        Nothing happens when no factory is configured or when ``doc`` is
        already flagged as a failure.  Any exception raised while creating or
        copying the embedding is not propagated; instead a new
        :class:`.AmrFailure` is assigned to each sentence of ``doc.amr``.

        :param doc: the document to receive the embedding

        :param doc_id: the identifier used in the timing and error messages;
                       when ``None`` it is taken from
                       ``doc.amr.get_doc_id()``

        """
        factory = self.word_piece_doc_factory
        # the factory test comes first so ``doc`` is never inspected when
        # embedding population is disabled
        if factory is None or doc.is_failure:
            return
        if doc_id is None:
            doc_id = doc.amr.get_doc_id()
        try:
            with time(f'populated embedding of document {doc_id}'):
                wp_doc: WordPieceFeatureDocument = factory(doc)
                wp_doc.copy_embedding(doc)
        except Exception as error:
            reason: str = \
                f"Could not process adhoc document: '{doc_id}': {error}"
            amr_sent: AmrSentence
            for amr_sent in doc.amr.sents:
                # each sentence gets its own failure instance
                amr_sent.failure = AmrFailure(
                    exception=error, thrower=self, message=reason)
[docs]
@dataclass
class AddEmbeddingsFeatureDocumentStash(DelegateStash, PrimeableStash,
                                        _EmbeddingPopulator):
    """Add embeddings to AMR feature documents.  Embedding population is
    disabled by configuring :obj:`word_piece_doc_factory` as ``None``.

    """
    def load(self, name: str) -> AmrFeatureDocument:
        """Load the document from the delegate stash and populate its
        embeddings.

        :param name: the key of the document to load

        :return: the document returned by the delegate, or ``None`` when the
                 delegate returns ``None`` for ``name``

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"loading '{name}'")
        doc: AmrFeatureDocument = self.delegate.load(name)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"loaded '{name}' -> {doc}")
        # the delegate may hand back ``None`` (see ``get``, which substitutes
        # a default in that case); populating would dereference it and raise
        # an ``AttributeError`` whenever a factory is configured
        if doc is not None:
            self._populate_embeddings(doc, name)
        return doc

    def get(self, name: str, default: Any = None) -> Any:
        """Return the document loaded for ``name``, or ``default`` when
        :meth:`load` returns ``None``.

        :param name: the key of the document to load

        :param default: the value to return when nothing was loaded

        """
        item = self.load(name)
        return default if item is None else item
[docs]
@dataclass
class CalamrAnnotatedAmrFeatureDocumentFactory(
        AnnotatedAmrFeatureDocumentFactory, _EmbeddingPopulator):
    """Adds wordpiece embeddings to
    :class:`~zensols.amr.container.AmrFeatureDocument` instances.

    """
    def to_annotated_doc(self, doc: AmrFeatureDocument) -> AmrFeatureDocument:
        """Delegate to the super class to create the annotated document, then
        populate its embeddings.

        :param doc: the document passed to the super class implementation

        :return: the document created by the super class

        """
        annotated = super().to_annotated_doc(doc)
        self._populate_embeddings(annotated)
        return annotated

    def from_dict(self, data: Dict[str, str]) -> AmrFeatureDocument:
        """Delegate to the super class to create the document from ``data``,
        then populate its embeddings when a word piece factory is configured.

        :param data: the data passed to the super class implementation

        :return: the document created by the super class

        """
        created = super().from_dict(data)
        if self.word_piece_doc_factory is None:
            # embedding population is disabled
            return created
        self._populate_embeddings(created)
        return created
[docs]
@dataclass
class ProxyReportAnnotatedAmrDocument(AnnotatedAmrDocument):
    """Overrides the sections property to skip duplicate summary sentences also
    found in the body.

    """
    @property
    def sections(self) -> Tuple[AnnotatedAmrSectionDocument]:
        """The sections from the super class with every sentence removed whose
        text matches the text of a summary sentence.

        The ``sents`` attribute of each section returned by the super class is
        reassigned in place to the filtered tuple.

        """
        # texts of the summary sentences, used for the membership test below
        summary_texts = {sent.text for sent in self.summary}
        sections = super().sections
        for section in sections:
            section.sents = tuple(
                sent for sent in section.sents
                if sent.text not in summary_texts)
        return sections