"""Prediction mapper support for NLP applications."""__author__='Paul Landes'fromtypingimportTuple,List,Dict,Iterablefromdataclassesimportdataclass,fieldfromitertoolsimportchainaschimportnumpyasnpfromzensols.configimportSettingsfromzensols.nlpimportFeatureSentence,FeatureDocumentfromzensols.deeplearn.vectorizeimportCategoryEncodableFeatureVectorizerfromzensols.deeplearn.modelimportPredictionMapperfromzensols.deeplearn.resultimportResultsContainerfromzensols.deepnlp.vectorizeimportFeatureDocumentVectorizerManagerfrom.importLabeledFeatureDocument
@dataclass
class ClassificationPredictionMapper(PredictionMapper):
    """A prediction mapper for text classification.  This mapper works at any
    level (document, sentence, token).

    """
    vec_manager: FeatureDocumentVectorizerManager = field()
    """The vectorizer manager used to parse and get the label vectorizer."""

    label_feature_id: str = field()
    """The feature ID for the label vectorizer."""

    pred_attribute: str = field(default='pred')
    """The prediction attribute to set on the :class:`.FeatureDocument`
    returned from :meth:`map_results`.

    """
    softmax_logit_attribute: str = field(default='softmax_logit')
    """The softmax of the logits attribute to set on the
    :class:`.FeatureDocument` returned from :meth:`map_results`.

    :see: `On Calibration of Modern Neural Networks
          <https://arxiv.org/abs/1706.04599>`_

    """
    def __post_init__(self):
        super().__post_init__()
        # documents parsed by _create_features, in call order; map_results
        # annotates and returns them
        self._docs: List[FeatureDocument] = []

    @property
    def label_vectorizer(self) -> CategoryEncodableFeatureVectorizer:
        """The label vectorizer used to map nominal predictions to class
        labels in :meth:`_map_classes`.

        """
        return self.vec_manager[self.label_feature_id]

    def _create_features(self, sent_text: str) -> Tuple[FeatureDocument, ...]:
        """Parse ``sent_text`` into a document, record it for
        :meth:`map_results`, and return it as the only feature.

        :param sent_text: the text to parse

        :return: a one element sequence holding the parsed document

        """
        doc: FeatureDocument = self.vec_manager.parse(sent_text)
        self._docs.append(doc)
        return [doc]

    def _map_classes(self, result: ResultsContainer) -> List[List[str]]:
        """Return the label string values for the nominal predictions in
        ``result``.

        :param result: the results container whose ``batch_predictions`` hold
                       the integers that map to the respective string class,
                       one entry per batch

        :return: a list of class labels for every batch in ``result``

        """
        vec: CategoryEncodableFeatureVectorizer = self.label_vectorizer
        nominals: List[np.ndarray] = result.batch_predictions
        return [vec.get_classes(cl).tolist() for cl in nominals]

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Return the softmax of ``logits`` normalized along the first axis
        (the label axis for a one dimensional logit vector).

        The maximum is subtracted before exponentiating so that large logits
        do not overflow to ``inf`` (which would produce ``nan``
        probabilities); the shift cancels out, so the result is
        mathematically unchanged.

        :param logits: the raw model outputs for a single data point

        :return: the normalized probabilities with the same shape as
                 ``logits``

        """
        if np.size(logits) == 0:
            # nothing to normalize, and ``np.max`` raises on empty input
            return np.exp(logits)
        exps: np.ndarray = np.exp(logits - np.max(logits, axis=0))
        return exps / np.sum(exps, axis=0)

    def map_results(self, result: ResultsContainer) -> \
            Tuple[LabeledFeatureDocument, ...]:
        """Map class predictions and logits on to the documents generated
        during use of this instance.  Predictions and logits are flattened
        across batches and paired positionally with the parsed documents.

        :param result: the results container with the batch predictions and
                       batch outputs (logits)

        :return: all documents parsed by this instance; each paired document
                 has its predicted class set on :obj:`pred_attribute` and a
                 ``dict`` of label to softmax probability set on
                 :obj:`softmax_logit_attribute`

        """
        class_groups: List[List[str]] = self._map_classes(result)
        classes: Iterable[str] = ch.from_iterable(class_groups)
        blogits: Iterable[np.ndarray] = ch.from_iterable(result.batch_outputs)
        docs: List[FeatureDocument] = self._docs
        labels: List[str] = self.label_vectorizer.label_encoder.classes_
        cl: str
        doc: FeatureDocument
        logits: np.ndarray
        # zip stops at the shortest input, so any document without a
        # matching prediction is returned without the attributes set
        for cl, doc, logits in zip(classes, docs, blogits):
            conf: np.ndarray = self._softmax(logits)
            sms: Dict[str, np.ndarray] = dict(zip(labels, conf))
            setattr(doc, self.pred_attribute, cl)
            setattr(doc, self.softmax_logit_attribute, sms)
        return tuple(docs)
@dataclass
class SequencePredictionMapper(ClassificationPredictionMapper):
    """A prediction mapper for sequence (token level) classification.  The
    predictions are given as a :class:`~zensols.config.serial.Settings` with
    key ``classes`` holding the token level predictions and key ``docs``
    holding the documents parsed from the sentence text.

    """
    def _create_features(self, sent_text: str) -> Tuple[FeatureSentence, ...]:
        """Parse ``sent_text``, keep the resulting document for later result
        mapping, and use its sentences as the features.

        :param sent_text: the text to parse

        :return: the sentences of the parsed document

        """
        parsed: FeatureDocument = self.vec_manager.parse(sent_text)
        self._docs.append(parsed)
        sents: Tuple[FeatureSentence, ...] = parsed.sents
        return sents