Source code for zensols.deeplearn.model.wgtexecutor
"""A class that weighs labels non-uniformly."""__author__='Paul Landes'fromtypingimportDict,Anyfromdataclassesimportdataclass,field,InitVarimportloggingimportcollectionsfrompathlibimportPathimporttorchfromzensols.utilimporttimefromzensols.persistimportpersisted,PersistedWorkfrom.importModelExecutorlogger=logging.getLogger(__name__)
[docs]
@dataclass
class WeightedModelExecutor(ModelExecutor):
    """A class that weighs labels non-uniformly.  This class uses inverse class
    sampling counts to help minority labels.

    """
    weighted_split_name: str = field(default='train')
    """The split name used to re-weight labels."""

    weighted_split_path: InitVar[Path] = field(default=None)
    """The path to the cached weighted labels."""

    use_weighted_criterion: bool = field(default=True)
    """If ``True``, use the class weights in the initializer of the criterion.
    Setting this to ``False`` effectively disables this class.

    """
    def __post_init__(self, weighted_split_path: Path):
        super().__post_init__()
        if weighted_split_path is None:
            path = '_label_counts'
        else:
            file_name = f'weighted-labels-{self.weighted_split_name}.dat'
            path = weighted_split_path / file_name
        self._label_counts = PersistedWork(path, self)
    @persisted('_label_counts')
    def get_label_counts(self) -> Dict[int, int]:
        stash = self.dataset_stash.splits[self.weighted_split_name]
        label_counts = collections.defaultdict(lambda: 0)
        batches = tuple(stash.values())
        for batch in batches:
            for label in batch.get_labels():
                label_counts[label.item()] += 1
        for batch in batches:
            batch.deallocate()
        return dict(label_counts)

    @persisted('_class_weights')
    def get_class_weights(self) -> torch.Tensor:
        """Compute the inverse of the class sampling counts as the class weights.

        """
        counts = self.get_label_counts().items()
        counts = map(lambda x: x[1], sorted(counts, key=lambda x: x[0]))
        counts = self.torch_config.from_iterable(counts)
        return counts.mean() / counts
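    # Illustrative sketch, not part of the original module: for a hypothetical
    # split with label counts {0: 10, 1: 30, 2: 60}, ``get_class_weights``
    # reduces to:
    #
    #   counts = torch.tensor([10., 30., 60.])
    #   weights = counts.mean() / counts   # tensor([3.3333, 1.1111, 0.5556])
    #
    # The minority label (index 0) receives the largest weight, so its errors
    # contribute proportionally more to the weighted loss.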
[docs]
    def get_label_statistics(self) -> Dict[str, Dict[str, Any]]:
        """Return a dictionary whose keys are the labels and values are
        dictionaries containing statistics on that label.

        """
        counts = self.get_label_counts()
        weights = self.get_class_weights().cpu().numpy()
        batch = next(iter(self.dataset_stash.values()))
        vec = batch.batch_stash.get_label_feature_vectorizer(batch)
        classes = vec.get_classes(range(weights.shape[0]))
        return {c[0]: {'index': c[1],
                       'count': counts[c[1]],
                       'weight': weights[c[1]]}
                for c in zip(classes, range(weights.shape[0]))}
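    # Illustrative sketch, not part of the original module: with the
    # hypothetical counts above and assumed label names 'a', 'b' and 'c',
    # ``get_label_statistics`` would return a mapping shaped like:
    #
    #   {'a': {'index': 0, 'count': 10, 'weight': 3.3333},
    #    'b': {'index': 1, 'count': 30, 'weight': 1.1111},
    #    'c': {'index': 2, 'count': 60, 'weight': 0.5556}}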
    def _create_criterion(self) -> torch.nn.Module:
        resolver = self.config_factory.class_resolver
        criterion_class_name = self.model_settings.criterion_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion: {criterion_class_name}')
        criterion_class = resolver.find_class(criterion_class_name)
        with time('weighted classes'):
            class_weights = self.get_class_weights()
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'using class weights: {class_weights}')
        if self.use_weighted_criterion:
            inst = criterion_class(weight=class_weights)
        else:
            inst = criterion_class()
        return inst
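# A minimal usage sketch, not part of the original module.  It assumes an
# application context whose ``executor`` section is configured to use
# WeightedModelExecutor (the section name is an assumption):
#
#   from zensols.config import ImportConfigFactory
#
#   factory = ImportConfigFactory(config)
#   executor: WeightedModelExecutor = factory.instance('executor')
#   print(executor.get_label_statistics())
#   executor.train()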