Source code for zensols.deeplearn.model.batchiter

"""Contains a class to assist in the batch loop during training, validation and
testing.

"""
__author__ = 'Paul Landes'

from typing import Any, Tuple, Dict
from dataclasses import dataclass, InitVar, field
import logging
from logging import Logger
from torch import Tensor
from torch import nn
from torch.optim import Optimizer
from zensols.deeplearn import ModelError, EarlyBailError, DatasetSplitType
from zensols.deeplearn.result import EpochResult
from zensols.deeplearn.batch import Batch, MetadataNetworkSettings
from . import BaseNetworkModule


@dataclass
class BatchIterator(object):
    """This class assists in the batch loop during training, validation and
    testing.  Any special handling of a model related to its loss function can
    be overridden in this class.

    .. document private functions
    .. automethod:: _decode_outcomes

    """
    executor: InitVar[Any] = field()
    """The owning executor."""

    logger: Logger = field()
    """The status logger from the executor."""

    def __post_init__(self, executor: Any):
        self.model_settings = executor.model_settings
        self.net_settings = executor.net_settings
        self.torch_config = executor.torch_config

    def _decode_outcomes(self, outcomes: Tensor) -> Tensor:
        """Transform the model output (and optionally the labels) that will be
        added to the ``EpochResult``, which composes a ``ModelResult``.

        This implementation returns :py:meth:`~torch.Tensor.argmax`, which
        gives the indexes of the max value across columns.

        """
        logger = self.logger
        reduce_outcomes = self.model_settings.reduce_outcomes
        # get the indexes of the max value across labels and outcomes (for the
        # discrete classification case)
        if reduce_outcomes == 'argmax':
            res = outcomes.argmax(dim=-1)
        # softmax over each outcome
        elif reduce_outcomes == 'softmax':
            res = outcomes.softmax(dim=-1)
        elif reduce_outcomes == 'none':
            # leave the output as is when a prediction/regression measure is
            # used
            res = outcomes
        else:
            raise ModelError(f'Unknown reduce_outcomes: {reduce_outcomes}')
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'argmax outcomes: {outcomes.shape} -> {res.shape}')
        return res

    def _encode_labels(self, labels: Tensor) -> Tensor:
        """Encode labels to be in the same form and on the same CUDA device as
        the batch data.  This base class implementation only copies to the
        GPU.

        :param labels: labels paired with the training and validation datasets

        :return: labels to be used in the loss function

        """
        logger = self.logger
        if not self.model_settings.nominal_labels:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'labels type: {labels.dtype}')
            labels = self.torch_config.to_type(labels)
        return labels

    def _debug_output(self, msg: str, labels: Tensor, output: Tensor):
        logger = self.logger
        if isinstance(self.debug, int) and self.debug > 1 and \
           logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'{msg}:')
            shape = None if labels is None else labels.shape
            dtype = None if labels is None else labels.dtype
            logger.debug(f'labels: {shape} ({dtype})')
            if isinstance(self.debug, int) and self.debug > 1:
                logger.debug(f'label values:\n{labels}')
            if output is None:
                logger.debug('output: <none>')
            else:
                logger.debug(f'output: {output.shape} ({output.dtype})')
                if isinstance(self.debug, int) and self.debug > 1:
                    logger.debug(f'\n{output}')

    def iterate(self, model: BaseNetworkModule, optimizer: Optimizer,
                criterion, batch: Batch, epoch_result: EpochResult,
                split_type: DatasetSplitType) -> Tensor:
        """Train, validate or test on a batch.  This uses the back propagation
        algorithm on training and does a simple feed forward on validation and
        testing.

        One call of this method represents a single batch iteration.

        :param model: the model to exercise

        :param optimizer: the optimization algorithm (e.g. Adam) to iterate

        :param criterion: the loss function (e.g. cross entropy loss) used for
                          the backward propagation step

        :param batch: contains the data to test, predict, and optionally the
                      labels for training and validation

        :param epoch_result: to be populated with the results of this epoch's
                             run

        :param split_type: indicates if we're training, validating or testing

        :return: the singleton tensor containing the loss

        """
        logger = self.logger
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'train/validate on {split_type}: ' +
                         f'batch={batch} ({id(batch)})')
            logger.debug(f'model on device: {model.device}')
        # copy batch to GPU if configured to do so
        batch: Batch = batch.to()
        outcomes: Tensor = None
        output: Tensor = None
        try:
            if self.debug:
                # write a batch sample when debugging; maybe make this a hook
                if isinstance(self.net_settings, MetadataNetworkSettings):
                    meta = self.net_settings.batch_metadata
                    meta.write()
                batch.write()
            # when training, reset gradients for the next epoch
            if split_type == DatasetSplitType.train:
                optimizer.zero_grad()
            # execute the epoch
            loss, labels, outcomes, output = self._execute(
                model, optimizer, criterion, batch, split_type)
            self._debug_output('decode', labels, outcomes)
            # if debugging the model, raise the exception to interrupt the
            # flow, which is caught in ModelExecutor._execute
            if self.debug:
                raise EarlyBailError()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'outcomes shape: {outcomes.shape}')
            # add results for performance metrics, predictions output, etc.
            epoch_result.update(batch, loss, labels, outcomes, output)
            return loss
        finally:
            # clean up and GPU memory deallocation
            biter = self.model_settings.batch_iteration
            cb = self.model_settings.cache_batches
            if (biter == 'cpu' and not cb) or biter == 'buffered':
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'deallocating batch: {batch}')
                batch.deallocate()

    def _step(self, loss: Tensor, split_type: DatasetSplitType,
              optimizer, model: BaseNetworkModule):
        """Iterate over the error surface."""
        # when training, backpropagate and step
        if split_type == DatasetSplitType.train:
            clip_thresh: float = self.model_settings.clip_gradient_threshold
            clip_params: Dict[str, Any] = \
                self.model_settings.scale_gradient_params
            # invoke back propagation on the network
            loss.backward()
            # clip the gradient
            if clip_thresh is not None:
                nn.utils.clip_grad_value_(model.parameters(), clip_thresh)
            # scale the gradient
            if clip_params is not None:
                nn.utils.clip_grad_norm_(model.parameters(), **clip_params)
            # take an update step and update the new weights
            optimizer.step()

    def _execute(self, model: BaseNetworkModule, optimizer: Optimizer,
                 criterion, batch: Batch, split_type: DatasetSplitType) -> \
            Tuple[Tensor]:
        """Execute one epoch of training, testing, validation or prediction.

        :param model: the model to exercise

        :param optimizer: the optimization algorithm (e.g. Adam) to iterate

        :param criterion: the loss function (e.g. cross entropy loss) used for
                          the backward propagation step

        :param batch: contains the data to test, predict, and optionally the
                      labels for training and validation

        :param split_type: indicates if we're training, validating or testing

        :return: a tuple of the loss, labels, outcomes, and the output
                 (i.e. logits); the outcomes are the decoded
                 (:meth:`_decode_outcomes`) output and represent some ready to
                 use data, like argmax'd classification nominal label integers

        """
        logger = self.logger
        labels: Tensor = batch.get_labels()
        # forward pass, get our output, which are usually the logits
        output: Tensor = model(batch)
        # sanity check
        if output is None:
            raise ModelError('Null model output')
        # check for sane state with labels, and munge if necessary
        if labels is None:
            # sanity check
            if split_type != DatasetSplitType.test:
                raise ModelError('Expecting no split type on prediction, ' +
                                 f'but got: {split_type}')
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('skipping loss calculation on prediction execute')
            loss = None
        else:
            # put labels in a form to be used by the loss function
            labels = self._encode_labels(labels)
            self._debug_output('input', labels, output)
            # calculate the loss with the log probabilities and the labels
            loss = criterion(output, labels)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'split: {split_type}, loss: {loss}')
            # iterate over the error surface
            self._step(loss, split_type, optimizer, model)
            self._debug_output('output', labels, output)
        # apply the same decoding on the labels as the output if necessary
        if labels is not None and not self.model_settings.nominal_labels:
            labels = self._decode_outcomes(labels)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'label nom decoded: {labels.shape}')
        outcomes = self._decode_outcomes(output)
        loss, labels, outcomes, output = self.torch_config.to_cpu_deallocate(
            loss, labels, outcomes, output)
        return loss, labels, outcomes, output
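

# The sketches below are illustrative additions for documentation and are not
# part of the original module.

# Sketch 1: how the ``reduce_outcomes`` model setting changes what
# ``_decode_outcomes`` returns for a toy batch of classification logits.  The
# tensor values are invented for demonstration only.
def _example_reduce_outcomes():
    import torch
    logits = torch.tensor([[0.1, 2.3, -0.5],
                           [1.7, 0.2, 0.9]])
    # 'argmax': index of the highest scoring class per row -> tensor([1, 0])
    print(logits.argmax(dim=-1))
    # 'softmax': per class probabilities that sum to one in each row
    print(logits.softmax(dim=-1))
    # 'none': the raw output is kept, e.g. for regression models
    print(logits)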
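

# Sketch 2: why ``_encode_labels`` converts non-nominal labels before the loss
# is computed.  With ``nominal_labels`` disabled, a float loss such as
# :class:`torch.nn.MSELoss` generally expects labels in the same dtype as the
# model output; the ``labels.to(...)`` call below stands in for the module's
# ``torch_config.to_type``.  The tensors are invented toy data.
def _example_encode_labels():
    import torch
    from torch import nn
    output = torch.randn(4)
    labels = torch.tensor([1, 0, 1, 1])
    # match the label dtype to the output before calling the criterion
    labels = labels.to(output.dtype)
    print(nn.MSELoss()(output, labels))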
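

# Sketch 3: the gradient clipping and scaling calls used by ``_step`` shown on
# a toy linear model.  The threshold and norm values are invented stand-ins
# for the ``clip_gradient_threshold`` and ``scale_gradient_params`` model
# settings.
def _example_gradient_step():
    import torch
    from torch import nn
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(model(torch.randn(8, 4)), torch.randint(0, 2, (8,)))
    optimizer.zero_grad()
    loss.backward()
    # clip_gradient_threshold: clamp each gradient element to [-0.5, 0.5]
    nn.utils.clip_grad_value_(model.parameters(), 0.5)
    # scale_gradient_params: rescale so the total gradient norm is at most 1.0
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()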
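

# Sketch 4: roughly how an executor-like caller might drive :meth:`iterate`
# over one training epoch.  The ``executor``, ``model``, ``optimizer``,
# ``criterion``, ``batches`` and ``epoch_result`` arguments are assumed to be
# supplied by the framework and are not constructed here.
def _example_drive_iterate(executor, model, optimizer, criterion,
                           batches, epoch_result):
    # hypothetical wiring: the executor supplies the model, network and torch
    # settings copied in ``__post_init__``
    batch_iter = BatchIterator(executor, logging.getLogger(__name__))
    model.train()
    for batch in batches:
        loss = batch_iter.iterate(model, optimizer, criterion, batch,
                                  epoch_result, DatasetSplitType.train)
        print(f'batch loss: {loss.item()}')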