"""Wrap the :mod:`amr_coref` module for AMR Co-refernce resolution."""__author__='Paul Landes'fromtypingimportDict,List,Tuple,Unionfromdataclassesimportdataclass,fieldimportloggingfrompathlibimportPathimporttextwrapastwimportplatformimporttorchfromzensols.utilimporttime,Hasherfromzensols.persistimportpersisted,Stash,DictionaryStashfromzensols.installimportInstallerfromamr_coref.coref.inferenceimportInferencefrom.importAmrFailure,AmrFeatureDocumentlogger=logging.getLogger(__name__)
@dataclass
class CoreferenceResolver(object):
    """Resolve coreferences in AMR graphs.

    Results are cached in :obj:`stash`, keyed on a hash of the document's
    sentence text, so a document already seen does not run the model again.

    """
    installer: Installer = field()
    """The :mod:`amr_coref` module's coreference module installer."""

    stash: Stash = field(default_factory=DictionaryStash)
    """The stash used to cache results.  It takes a while to inference but the
    results in memory size are small.

    """
    use_multithreading: bool = field(default=True)
    """By default, multithreading is enabled for Linux systems.  However, an
    error is raised when invoked from a child thread.  Set to ``False`` to turn
    off multithreading for coreference resolution.

    """
    robust: bool = field(default=True)
    """Whether to robustly deal with exceptions in the coreference model.  If
    ``True``, instances of :class:`.AmrFailure` are stored in the stash and
    empty coreferences used for caught errors.  If ``False``, exceptions raised
    by the model propagate to the caller and nothing is cached.

    """
    def _use_multithreading(self) -> bool:
        """Return whether the model should use multithreading: only when
        :obj:`use_multithreading` is set *and* the platform is Linux.

        """
        return self.use_multithreading and platform.system() == 'Linux'

    @property
    @persisted('_model')
    def model(self) -> Inference:
        """The :mod:`amr_coref` coreference model, created on first access and
        persisted (via :func:`persisted`) in the ``_model`` attribute.

        """
        use_multithreading: bool = bool(self._use_multithreading())
        if not use_multithreading and logger.isEnabledFor(logging.INFO):
            logger.info('turning off AMR coref multithreading for ' +
                        f'platform: {platform.system()}')
        # run the installer (presumably a no-op once the model is present)
        # before asking it where the model files live
        self.installer()
        model_path: Path = self.installer.get_singleton_path()
        # ``None`` presumably lets the model choose the CUDA device; without
        # CUDA the CPU is forced explicitly
        device = None if torch.cuda.is_available() else 'cpu'
        return Inference(str(model_path), device=device,
                         use_multithreading=use_multithreading)

    def _resolve(self, doc: AmrFeatureDocument) -> \
            Dict[str, List[Tuple[int, str]]]:
        """Use the coreference model and return the output.

        :param doc: the document whose sentences' AMR graph strings are given
                    to the model

        :return: the output of :meth:`Inference.coreference`

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'resolving coreferences for {doc}')
        model: Inference = self.model
        graph_strs: Tuple[str, ...] = tuple(
            map(lambda s: s.amr.graph_string, doc.sents))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('resolving coreferences graph:')
            for gs in graph_strs:
                logger.debug(f' {gs}')
        with time(f'resolved {len(doc)} sentence coreferences'):
            return model.coreference(graph_strs)

    def _create_key(self, doc: AmrFeatureDocument) -> str:
        """Create a unique key based on the text of the sentences of the
        document.

        Each sentence's text is prefixed with its length so that documents
        whose concatenated text is identical but whose sentence boundaries
        differ (e.g. ``['ab', 'c']`` vs ``['a', 'bc']``) get different keys;
        otherwise cached relations could carry the wrong sentence indices.
        Note that changing this scheme orphans entries already in a persistent
        :obj:`stash` (they are simply recomputed).

        """
        hasher = Hasher()
        for sent in doc:
            text: str = sent.text
            hasher.update(f'{len(text)}:{text}')
        return hasher()

    def clear(self):
        """Clear the stash cache."""
        self.stash.clear()

    def __call__(self, doc: AmrFeatureDocument):
        """Resolve the coreferences of the AMR sentences of the document and
        set them on ``doc.coreference_relations``.  If the document is cached
        in :obj:`stash` use that.  Otherwise use the model to compute it and
        cache the result.

        Each relation set is a tuple of ``(<document index>, <variable>)``
        tuples, one per value of the model's output.

        If the model raises and :obj:`robust` is ``True``, an
        :class:`.AmrFailure` is logged and cached, and
        ``doc.coreference_relations`` is not assigned.  If :obj:`robust` is
        ``False`` the exception is re-raised and nothing is cached.

        :param doc: the document used in the model to perform coreference
                    resolution

        """
        key: str = self._create_key(doc)
        ref: Union[AmrFailure, Dict[str, List[Tuple[int, str]]]] = \
            self.stash.load(key)
        if ref is None:
            try:
                ref = self._resolve(doc)
            except Exception as e:
                if not self.robust:
                    # honor the non-robust setting: do not swallow or cache
                    raise
                text: str = tw.shorten(doc.text, width=60)
                ref = AmrFailure(
                    exception=e,
                    message=f'could not co-reference <{text}>',
                    sent=doc.text)
                logger.warning(str(ref))
            # cache both successes and (robust) failures so neither is redone
            self.stash.dump(key, ref)
        if not isinstance(ref, AmrFailure):
            doc.coreference_relations = tuple(map(tuple, ref.values()))