"""Multi processing with torch.

Provides a :class:`zensols.multi.MultiProcessStash` subclass that uses
:mod:`torch.multiprocessing` pools for its worker processes.

"""
__author__ = 'Paul Landes'

from typing import Callable, List, Iterable, Any
from dataclasses import dataclass
import logging
from torch.multiprocessing import Pool as TorchPool
from zensols.util.time import time
from zensols.multi import MultiProcessStash

# module level logger; messages are emitted only when INFO is enabled
logger = logging.getLogger(__name__)
@dataclass
class TorchMultiProcessStash(MultiProcessStash):
    """A multiprocessing stash that interacts with PyTorch in a way that it can
    access the GPU(s) in forked subprocesses using the :mod:`multiprocessing`
    library.

    :see: :mod:`torch.multiprocessing`

    :see: :meth:`zensols.deeplearn.TorchConfig.init`

    """
    def _invoke_pool(self, pool: TorchPool, fn: Callable,
                     data: Iterable[Any]) -> List[int]:
        """Invoke on a torch pool (rather than a :class:`multiprocessing.Pool`).

        :param pool: the torch pool used to map ``fn`` over ``data``, or
                     ``None`` to apply ``fn`` serially in this process

        :param fn: the callable applied to each element of ``data``

        :param data: the work items, each of which is passed to ``fn``

        :return: the results of ``fn`` in the same order as ``data``

        """
        if pool is None:
            # build a list (not a tuple) so the serial path returns the same
            # type as ``Pool.map`` does in the parallel path, as annotated
            return list(map(fn, data))
        return pool.map(fn, data)

    def _invoke_work(self, workers: int, chunk_size: int,
                     data: Iterable[Any]) -> List[int]:
        """Map this class's ``_process_work`` over ``data``, either in the
        current process or across a :mod:`torch.multiprocessing` pool.

        :param workers: the number of worker processes; when ``1`` the work
                        runs in the current process and no pool is created

        :param chunk_size: only used in the log message in this method

        :param data: the work items given to ``_process_work`` (presumably
                     chunks of keys; confirm against the base class)

        :return: the per item results of ``_process_work``, in order

        """
        # unbound class attribute, so it can be shipped to pool workers
        # (assumes ``_process_work`` is a static/class method in the base)
        fn: Callable = self.__class__._process_work
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'{self.name}: spawning work with ' +
                        f'chunk size {chunk_size} across {workers} workers')
        if workers == 1:
            # skip the cost of spawning a pool for a single worker
            with time('processed chunks'):
                return self._invoke_pool(None, fn, data)
        # ``map`` blocks until all work completes, so returning from inside
        # the context is safe: the pool is only torn down afterwards
        with TorchPool(workers) as pool:
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'using torch multiproc pool: {pool}')
            with time('processed chunks'):
                return self._invoke_pool(pool, fn, data)