"""Provides a class to graph the results."""__author__='Paul Landes'fromtypingimportList,Tuplefromdataclassesimportdataclass,fieldimportloggingfrompathlibimportPathfromzensols.utilimportAPIErrorfromzensols.deeplearnimportDatasetSplitTypefrom.importModelResult,DatasetResultlogger=logging.getLogger(__name__)
@dataclass
class ModelResultGrapher(object):
    """Graphs an instance of :class:`.ModelResult`.  This creates subfigures,
    one for each of the results given as input to :meth:`plot_loss`.

    """
    name: str = field(default=None)
    """The name that goes in the title of the graph.  If not set, the name of
    the first result given to :meth:`plot_loss` is used.

    """
    figsize: Tuple[int, int] = field(default=(15, 5))
    """The size of the top level figure (not the panes).  The height is
    multiplied by the number of results plotted.

    """
    split_types: List[DatasetSplitType] = field(default=None)
    """The splits to graph, one loss curve per split in each pane; defaults to
    ``[DatasetSplitType.train, DatasetSplitType.validation]``.

    """
    title: str = field(default=None)
    """The title format used to create each sub pane graph.  It is formatted
    with keys ``r`` (the :class:`.ModelResult`) and ``learning_rate``.

    """
    save_path: Path = field(default=None)
    """Where the plot is saved."""

    def __post_init__(self):
        if self.split_types is None:
            self.split_types = [DatasetSplitType.train,
                                DatasetSplitType.validation]
        if self.title is None:
            self.title = ('Result {r.name} ' +
                          '(lr={learning_rate:e}, ' +
                          '{r.last_test.converged_epoch.metrics})')

    def _render_title(self, cont: ModelResult) -> str:
        """Format :obj:`title` for the sub pane of a single result.

        :param cont: the result whose ``model_settings['learning_rate']`` and
                     attributes are substituted into the title format

        """
        lr = cont.model_settings['learning_rate']
        return self.title.format(**{'r': cont, 'learning_rate': lr})

    def _assert_plotted(self):
        """Raise an error if no figure has been created yet.

        :raises APIError: if :meth:`plot_loss` has not been called

        """
        if not hasattr(self, '_fig'):
            raise APIError('Must first call :meth:`plot_loss` before ' +
                           'showing or saving')

    def plot_loss(self, containers: List[ModelResult]):
        """Create a figure with one loss sub pane (row) per result.  Each pane
        has one curve for each split in :obj:`split_types`.  The figure is
        kept on this instance so that :meth:`show` or :meth:`save` can render
        it afterward.

        :param containers: the results to graph

        :raises APIError: if ``containers`` is empty

        """
        from .fig import Figure
        from .plots import LossPlot
        if not containers:
            # without this, either ``containers[0]`` fails with an IndexError
            # or an empty, zero-height figure is created
            raise APIError('No results given to plot')
        name: str = containers[0].name if self.name is None else self.name
        title: str = f'Training and Validation Learning Rates: {name}'
        fig = Figure(
            name=title,
            title_font_size=14,
            width=self.figsize[0],
            # each result gets its own row, so grow the height accordingly
            height=self.figsize[1] * len(containers))
        fig.path = self.save_path
        row: int
        cont: ModelResult
        for row, cont in enumerate(containers):
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'plotting {cont}')
            results: Tuple[Tuple[str, DatasetResult], ...] = tuple(
                (split.name.capitalize(), cont.dataset_result[split])
                for split in self.split_types)
            plot = LossPlot(title=self._render_title(cont), row=row)
            split_name: str
            result: DatasetResult
            for split_name, result in results:
                plot.add(split_name, result.losses)
            fig.add_plot(plot)
        self._fig = fig

    def show(self):
        """Render and display the plot.

        :raises APIError: if :meth:`plot_loss` has not been called

        """
        self._assert_plotted()
        self._fig.show()

    def save(self):
        """Save the plot to disk at :obj:`save_path`, then clear the figure.

        :raises APIError: if :meth:`plot_loss` has not been called

        """
        self._assert_plotted()
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'saving results graph to {self.save_path}')
        self._fig.save()
        self._fig.clear()