Source code for zensols.nlp.combine

"""A class that combines features.

"""
__author__ = 'Paul Landes'

from typing import Set, List, Dict, Any
from dataclasses import dataclass, field
import logging
from . import (
    ParseError, TokenContainer, FeatureDocumentParser,
    FeatureDocument, FeatureSentence, FeatureToken,
)
from .parser import DecoratedFeatureDocumentParser

logger = logging.getLogger(__name__)


@dataclass
class CombinerFeatureDocumentParser(DecoratedFeatureDocumentParser):
    """A class that combines features from two or more
    :class:`.FeatureDocumentParser` instances.  Features parsed by each parser
    in :obj:`source_parsers` are optionally copied or overwritten on a
    token-by-token basis in the feature document parsed by this instance.
    Target tokens are added to or clobbered from the source, but never the
    other way around.

    """
    source_parsers: List[FeatureDocumentParser] = field(default=None)
    """The parsers that create the source documents whose token features are
    copied to the target document.

    """
    validate_features: Set[str] = field(default_factory=set)
    """A set of features to compare across all tokens when copying.  If any of
    the given features don't match, a token mismatch error is raised.

    """
    yield_features: List[str] = field(default_factory=list)
    """A list of features to be copied (in order) if the target token is not
    set.

    """
    yield_feature_defaults: Any = field(default=None)
    """A default value to use when no yielded value is found.  If ``None``, do
    not add the feature if missing.

    """
    overwrite_features: List[str] = field(default_factory=list)
    """A list of features to be copied/overwritten in the order given in the
    list.

    """
    overwrite_nones: bool = field(default=False)
    """Whether to write ``None`` for missing :obj:`overwrite_features`.  This
    always writes the *target* feature; to copy only when the *target* feature
    is not set or missing, use :obj:`yield_features` instead.

    """
    include_detached_features: bool = field(default=True)
    """Whether to include copied (yielded or overwritten) features as listed
    detached features.  This controls what is compared, cloned and printed in
    :meth:`~zensols.config.writable.Writable.write`.

    :see: :obj:`.FeatureToken.default_detached_feature_ids`

    """
    def _validate_features(self, target_tok: FeatureToken,
                           source_tok: FeatureToken,
                           context_container: TokenContainer):
        for f in self.validate_features:
            prim = getattr(target_tok, f)
            rep = getattr(source_tok, f)
            if prim != rep:
                raise ParseError(
                    f'Mismatch tokens: {target_tok.text}({f}={prim}) ' +
                    f'!= {source_tok.text}({f}={rep}) ' +
                    f'in container: {context_container}')

    def _merge_tokens(self, target_tok: FeatureToken,
                      source_tok: FeatureToken,
                      context_container: TokenContainer):
        overwrite_nones: bool = self.overwrite_nones
        include_detached: bool = self.include_detached_features
        yield_default: Any = self.yield_feature_defaults
        if logger.isEnabledFor(logging.TRACE):
            logger.trace(f'merging tokens: {source_tok} ({type(source_tok)}) '
                         f'-> {target_tok} ({type(target_tok)})')
        self._validate_features(target_tok, source_tok, context_container)
        f: str
        for f in self.yield_features:
            # copy the source feature only when the target does not have it
            targ = getattr(target_tok, f) if hasattr(target_tok, f) else None
            if logger.isEnabledFor(logging.TRACE):
                logger.trace(f'yield feature: {f}, target={targ}')
            if targ is None or targ == FeatureToken.NONE:
                src = getattr(source_tok, f) if hasattr(source_tok, f) else None
                if logger.isEnabledFor(logging.TRACE):
                    logger.trace(f'yield feature: {f}, src={src}')
                if (src is None or src == FeatureToken.NONE) and \
                   yield_default is not None:
                    src = yield_default
                if src is not None:
                    if logger.isEnabledFor(logging.TRACE):
                        logger.trace(f'{src} -> {target_tok.text}.{f}')
                    setattr(target_tok, f, src)
                    if include_detached and \
                       target_tok._detatched_feature_ids is not None:
                        target_tok._detatched_feature_ids.add(f)
        for f in self.overwrite_features:
            # clobber the target feature with the source parser's value
            if overwrite_nones:
                src = getattr(source_tok, f)
            else:
                src = source_tok.get_value(f)
            src = FeatureToken.NONE if src is None else src
            if logger.isEnabledFor(logging.TRACE):
                prev = target_tok.get_value(f)
                logger.trace(
                    f'overwrite: {src} -> {prev} ({target_tok.text}.{f})')
            setattr(target_tok, f, src)
            if include_detached and \
               target_tok._detatched_feature_ids is not None:
                target_tok._detatched_feature_ids.add(f)

    def _debug_sentence(self, sent: FeatureSentence, name: str):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'{name}:')
            for i, tok in enumerate(sent.tokens):
                logger.debug(f'  {i}: i={tok.i}, pos={tok.pos_}, ' +
                             f'ent={tok.ent_}: {tok}')

    def _merge_sentence(self):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'merging sentences: {self._source_sent.tokens} ' +
                         f'-> {self._target_sent.tokens}')
        n_toks: int = 0
        assert len(self._target_sent) == len(self._source_sent)
        for target_tok, source_tok in zip(
                self._target_sent, self._source_sent):
            assert target_tok.idx == source_tok.idx
            self._merge_tokens(target_tok, source_tok, self._target_sent)
            n_toks += 1
        assert n_toks == len(self._target_sent)

    def _prepare_merge_doc(self):
        pass

    def _complete_merge_doc(self):
        pass

    def _merge_doc(self):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'merging docs: {self._source_doc} -> ' +
                         f'{self._target_doc}')
        for target_sent, source_sent in zip(
                self._target_doc, self._source_doc):
            self._target_sent = target_sent
            self._source_sent = source_sent
            self._prepare_merge_doc()
            try:
                self._merge_sentence()
            finally:
                del self._target_sent
                del self._source_sent
                self._complete_merge_doc()
        self._target_doc._combine_update(self._source_doc)

    def _merge_docs(self, target_doc: FeatureDocument,
                    source_doc: FeatureDocument):
        self._target_doc = target_doc
        self._source_doc = source_doc
        try:
            self._merge_doc()
        finally:
            del self._target_doc
            del self._source_doc

    def _parse(self, parsed: Dict[int, FeatureDocument], text: str,
               *args, **kwargs) -> FeatureDocument:
        self._log_parse(text, logger)
        key: int = id(self.delegate)
        target_doc: FeatureDocument = parsed.get(key)
        if target_doc is None:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'parsing with {self.delegate}')
            target_doc = self.delegate.parse(text, *args, **kwargs)
            parsed[key] = target_doc
        else:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'reusing parsed doc from {self.delegate}')
        if self.source_parsers is None or len(self.source_parsers) == 0:
            logger.warning(f'No source parsers set on {self}, ' +
                           'which disables feature combining')
        else:
            for source_parser in self.source_parsers:
                source_doc: FeatureDocument
                if isinstance(source_parser, CombinerFeatureDocumentParser):
                    # nested combiners share the parse cache so each delegate
                    # parses the text at most once
                    source_doc = source_parser._parse(
                        parsed, text, *args, **kwargs)
                else:
                    source_doc = source_parser.parse(text, *args, **kwargs)
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        f'merging {source_parser} -> {self.delegate}')
                self._merge_docs(target_doc, source_doc)
                parsed[id(source_parser)] = source_doc
        self.decorate(target_doc)
        return target_doc
    def parse(self, text: str, *args, **kwargs) -> FeatureDocument:
        return self._parse({}, text, *args, **kwargs)
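
The following is a minimal usage sketch rather than part of the module: the
``base_parser`` and ``ner_parser`` names are hypothetical stand-ins for
:class:`.FeatureDocumentParser` instances that would normally be assembled
from application configuration, and ``delegate`` is the field inherited from
:class:`.DecoratedFeatureDocumentParser`.

    # hypothetical wiring (a sketch, not from the zensols documentation):
    # 'base_parser' creates the target document and 'ner_parser' supplies
    # the named entity ('ent_') feature clobbered onto the target tokens
    combiner = CombinerFeatureDocumentParser(
        delegate=base_parser,
        source_parsers=[ner_parser],
        overwrite_features=['ent_'])
    doc: FeatureDocument = combiner.parse('Paul Landes works in Chicago.')
    tok: FeatureToken
    for tok in doc.token_iter():
        # each target token now carries the source parser's 'ent_' value
        print(tok.text, tok.ent_)
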
@dataclass
class MappingCombinerFeatureDocumentParser(CombinerFeatureDocumentParser):
    """Maps the source to respective tokens in the target document using spaCy
    artifacts.

    """
    validate_features: Set[str] = field(default=frozenset({'idx'}))
    """A set of features to compare across all tokens when copying.  If any of
    the given features don't match, a token mismatch error is raised.  The
    default is the token's index in the document, which should not change in
    most cases.

    """
    merge_sentences: bool = field(default=True)
    """If ``False``, ignore sentences and map everything at the token level.
    Otherwise, the same hierarchy mapping as the super class is used.  This is
    useful when sentence demarcations are not aligned across the source
    document parsers and this parser.

    """
    def _merge_entities_by_token(self, target_tok, source_tok):
        """Add the source sentence entity spans to the target sentence.  This
        is important so the original spaCy predicted entities are merged from
        the source to the target.

        """
        tdoc, sdoc = self._target_doc, self._source_doc
        ssent = None
        try:
            ssent = next(sdoc.get_overlapping_sentences(source_tok.lexspan))
        except StopIteration:
            pass
        if ssent is not None:
            tsent = next(tdoc.get_overlapping_sentences(target_tok.lexspan))
            # copy source entity spans that start at this token, skipping
            # spans the target sentence already has
            skips = set(tsent._ents)
            for ent in ssent._ents:
                begin, end = ent
                if begin == source_tok.idx and ent not in skips:
                    tsent._ents.append(ent)

    def _merge_token_containers(self, target_container: TokenContainer,
                                rmap: Dict[int, FeatureToken]):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'merge: {target_container}, mapping: {rmap}')
        include_detached: bool = self.include_detached_features
        visited: Set[FeatureToken] = set()
        targ_toks: Set[FeatureToken] = set(target_container.token_iter())
        for target_tok in target_container.token_iter():
            source_tok: FeatureToken = rmap.get(target_tok.idx)
            if logger.isEnabledFor(logging.TRACE):
                logger.trace(f'entry: {target_tok.idx}/{target_tok} ' +
                             f'-> {source_tok}')
            if source_tok is not None:
                visited.add(source_tok)
                self._merge_tokens(target_tok, source_tok, target_container)
                if source_tok.ent_ != FeatureToken.NONE:
                    self._merge_entities_by_token(target_tok, source_tok)
        for target_tok in (targ_toks - visited):
            # back-fill unmatched target tokens with the NONE sentinel
            fname: str
            for fname in self.overwrite_features:
                if not hasattr(target_tok, fname):
                    setattr(target_tok, fname, FeatureToken.NONE)
                    if include_detached and \
                       target_tok._detatched_feature_ids is not None:
                        target_tok._detatched_feature_ids.add(fname)

    def _merge_sentence(self):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'merging sentences: {self._source_sent.tokens} ' +
                         f'-> {self._target_sent.tokens}')
        rmp: Dict[int, FeatureToken] = self._source_token_mapping
        self._merge_token_containers(self._target_sent, rmp)
        self._target_doc._combine_update(self._source_doc)

    def _prepare_merge_doc(self):
        if self.merge_sentences:
            self._source_token_mapping = self._source_doc.tokens_by_idx

    def _complete_merge_doc(self):
        if self.merge_sentences:
            del self._source_token_mapping

    def _merge_doc(self):
        if self.merge_sentences:
            super()._merge_doc()
        else:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'merging docs: {self._source_doc} -> ' +
                             f'{self._target_doc}')
            source_token_mapping = self._source_doc.tokens_by_idx
            self._merge_token_containers(
                self._target_doc, source_token_mapping)
            self._target_doc._combine_update(self._source_doc)
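
A similar sketch for the mapping variant, again with hypothetical parser
names: setting ``merge_sentences`` to ``False`` maps tokens document-wide by
their ``idx`` (character offset) when the source parser's sentence boundaries
do not line up with the delegate's.

    # hypothetical sketch: 'chunker_parser' splits sentences differently
    # than the delegate, so tokens are matched by character offset instead
    # of by aligned sentence position
    mapper = MappingCombinerFeatureDocumentParser(
        delegate=base_parser,
        source_parsers=[chunker_parser],
        yield_features=['lemma_'],
        merge_sentences=False)
    doc = mapper.parse('Dr. Landes wrote it.  The parsers disagree on breaks.')
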