Source code for zensols.deeplearn.batch.multi

"""Multi processing with torch.

"""
__author__ = 'Paul Landes'

from typing import Callable, List, Iterable, Any
from dataclasses import dataclass
import logging
from torch.multiprocessing import Pool as TorchPool
from zensols.util.time import time
from zensols.multi import MultiProcessStash

logger = logging.getLogger(__name__)


@dataclass
class TorchMultiProcessStash(MultiProcessStash):
    """A multiprocessing stash that interacts with PyTorch in a way that it can
    access the GPU(s) in forked subprocesses using the :mod:`multiprocessing`
    library.

    :see: :mod:`torch.multiprocessing`

    :see: :meth:`zensols.deeplearn.TorchConfig.init`

    """
    def _invoke_pool(self, pool: TorchPool, fn: Callable,
                     data: Iterable[Any]) -> List[int]:
        """Invoke on a torch pool (rather than a :class:`multiprocessing.Pool`).

        :param pool: the pool used to map ``fn`` over ``data``, or ``None`` to
                     apply ``fn`` serially in the calling process

        :param fn: the callable applied to each element of ``data``

        :param data: the items to hand to ``fn``

        :return: the results of applying ``fn`` to the elements of ``data``

        """
        if pool is None:
            # materialize as a list so the serial path returns the same type
            # as ``pool.map`` and matches the declared return type
            return list(map(fn, data))
        return pool.map(fn, data)

    def _invoke_work(self, workers: int, chunk_size: int,
                     data: Iterable[Any]) -> List[int]:
        """Process ``data`` either in the calling process or across a torch
        multiprocessing pool.

        :param workers: the number of worker processes; a value of ``1``
                        processes the data in the calling process without
                        creating a pool

        :param chunk_size: the size of each chunk; used here only for the log
                           message

        :param data: the items to hand to ``_process_work``

        :return: the results returned by :meth:`_invoke_pool`

        """
        # looked up on the class (not the instance) so the function handed to
        # the pool is not bound to ``self``
        fn = self.__class__._process_work
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'{self.name}: spawning work with '
                        f'chunk size {chunk_size} across {workers} workers')
        if workers == 1:
            # a single worker needs no pool: run serially in this process
            with time('processed chunks'):
                results = self._invoke_pool(None, fn, data)
        else:
            with TorchPool(workers) as p:
                if logger.isEnabledFor(logging.INFO):
                    logger.info(f'using torch multiproc pool: {p}')
                with time('processed chunks'):
                    results = self._invoke_pool(p, fn, data)
        return results