"""A split key container for leave-one-out dataset splits."""__author__='Paul Landes'fromtypingimportDict,Tuple,Setfromdataclassesimportdataclass,fieldimportsysimportrandomasrandimportitertoolsasitfromcollectionsimportdequefrompathlibimportPathfromzensols.persistimportpersisted,PersistedWork,Stashfrom.importDatasetError,SplitKeyContainer


@dataclass
class LeaveNOutSplitKeyContainer(SplitKeyContainer):
    """A split key container that leaves one out of the dataset.  By default,
    this creates a dataset that has one data point for validation, another for
    test, and the rest of the data for training.

    """
    delegate: Stash = field()
    """The source for keys to generate the splits."""

    distribution: Dict[str, int] = field(
        default_factory=lambda: {'train': -1, 'validation': 1, 'test': 1})
    """The number of data points by each split type.  If the value is an
    integer, that number of data points is used.  Otherwise, if it is a float,
    then that percentage of the entire key set is used.

    """

    shuffle: bool = field(default=True)
    """If ``True``, shuffle the keys obtained from :obj:`delegate` before
    creating the splits.

    """

    path: Path = field(default=None)
    """If not ``None``, persist the keys after shuffling (if enabled) to the
    path specified, for reproducibility of key partitions.

    """

    def __post_init__(self):
        path = '_key_queue' if self.path is None else self.path
        self._key_queue = PersistedWork(path, self, mkdir=True)
        self._iter = 0

    @persisted('_key_queue')
    def _get_key_queue(self) -> deque:
        keys = list(self.delegate.keys())
        if self.shuffle:
            rand.shuffle(keys)
        return deque(keys)

    def next_split(self) -> bool:
        """Create the next split so that the next access to properties such as
        :obj:`keys_by_split` provides the next key split permutation.

        """
        key_queue = self._get_key_queue()
        key_queue.rotate(-1)
        self._iter += 1
        return (self._iter % len(key_queue)) == 0

    @persisted('_split_names')
    def _get_split_names(self) -> Set[str]:
        return frozenset(self.distribution.keys())

    def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
        kq = self._get_key_queue()
        keys = iter(kq)
        klen = len(kq)
        ds = self.distribution.items()
        filled = False
        by_split = {}
        for name, n in sorted(ds, key=lambda x: x[1], reverse=True):
            if n < 0:
                if filled:
                    raise DatasetError("Distribution has more than one " +
                                       f"'fill' (-1) value: {ds}")
                filled = True
                n = sys.maxsize
            elif isinstance(n, float):
                n = int(n * klen)
            by_split[name] = tuple(it.islice(keys, n))
        total = sum(map(lambda x: len(x), by_split.values()))
        if total != klen:
            raise DatasetError(
                f'Number of allocated keys to the distribution ({total}) ' +
                f'does not equal total keys ({klen})')
        return by_split
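

if __name__ == '__main__':
    # Usage sketch (not part of the original module): exercise the container
    # with a small in-memory stash.  This assumes ``DictionaryStash`` from
    # :mod:`zensols.persist` has a no-argument constructor and the standard
    # ``Stash.dump`` API; substitute whatever stash backs your dataset.
    from zensols.persist import DictionaryStash

    stash = DictionaryStash()
    for i, value in enumerate('abcdef'):
        stash.dump(str(i), value)

    # with the default distribution, one key is held out for validation, one
    # for test, and the remaining keys (the -1 "fill" entry) go to training;
    # a float value such as 0.5 would instead allocate that fraction of keys
    container = LeaveNOutSplitKeyContainer(delegate=stash, shuffle=False)

    # per the next_split docstring, each call rotates the key queue so the
    # held-out validation and test points cycle through the whole dataset;
    # it returns True once the full cycle of permutations has been visited
    done = False
    while not done:
        print(container.keys_by_split)
        done = container.next_split()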