Source code for zensols.deeplearn.model.trainmng

"""Contains a class to assist in the training of the :class:`.ModelExecutor`.

"""
__author__ = 'Paul Landes'

from typing import Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
import sys
import logging
from logging import Logger
import json
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from zensols.deeplearn import ModelError
from zensols.deeplearn.result import EpochResult


class UpdateAction(Enum):
    """An action type to invoke on the :class:`.ModelExecutor` during
    training.

    """
    ITERATE_EPOCH = 0
    SET_EPOCH = 1
    STOP = 2


@dataclass
class TrainStatus(object):
    """Indicates what to do in the next epoch of the training cycle.

    """
    action: UpdateAction
    epoch: int = field(default=None)
    reason: str = field(default=None)


@dataclass
class TrainManager(object):
    """The class is used to assist in the training of the
    :class:`.ModelExecutor`.  It updates validation loss and helps with early
    stopping decisions.  It also watches for a file on the file system to
    provide instructions on what to do in the next epoch.

    """
    status_logger: Logger = field()
    """The logger to record status updates during training."""

    progress_logger: Logger = field()
    """The logger to record progress updates during training.  This is used
    only when the progress bar is turned off (see
    :meth:`.ModelFacade._configure_cli_logging`).

    """

    update_path: Path = field()
    """See :obj:`.ModelExecutor.update_path`."""

    max_consecutive_increased_count: int = field()
    """See :obj:`.Domain.max_consecutive_increased_count`."""

    progress_bar_number_width: int = field(default=6)
    """The string width of the train/validation loss metrics in the progress
    bar, which needs to be greater than 4.

    """
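    # Editorial note (not in the original source): attributes such as
    # ``optimizer``, ``scheduler``, ``n_epochs``, ``current_epoch``,
    # ``consecutive_increased_count``, ``valid_loss_min``, ``pbar`` and
    # ``validation_loss_decreases`` are not dataclass fields; they are
    # (re)initialized in :meth:`start` at the beginning of each training run.
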
    def start(self, optimizer: torch.optim.Optimizer, scheduler: Any,
              n_epochs: int, pbar: tqdm):
        """Start the training cycle and reset the state of this manager."""
        # clear any early stop state
        if self.update_path is not None and self.update_path.is_file():
            self.progress_logger.info(
                f'cleaning update file: {self.update_path}')
            self.update_path.unlink()
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.n_epochs = n_epochs
        self.current_epoch = 0
        self.consecutive_increased_count = 0
        # set initial "min" to infinity
        self.valid_loss_min = np.Inf
        self.pbar = pbar
        if self.progress_logger.isEnabledFor(logging.INFO):
            self.progress_logger.info(
                f'watching update file {self.update_path}')
        self.validation_loss_decreases = 0

    def _get_optimizer_lr(self, optimizer: torch.optim.Optimizer) -> float:
        """Return the current optimizer learning rate, which can be modified
        by a scheduler if one is configured.

        """
        param_group = next(iter(optimizer.param_groups))
        return float(param_group['lr'])

    def _fixed_sci_format(self, v: float) -> str:
        """Format a number to a fixed width, resorting to scientific notation
        where necessary.  The returned string is left padded with spaces in
        cases where scientific notation is too wide for ``v > 0``.  The
        mantissa is also cut off for ``v > 0`` when the string version of the
        number is too wide.

        """
        length: int = self.progress_bar_number_width
        n: int = length
        ln: int = None
        pad: int = None
        while n > 0:
            i = len('%#.*g' % (n, v))
            s = '%.*g' % (n + n - i, v)
            ln = len(s)
            pad = length - ln
            if pad >= 0:
                break
            n -= 1
        if pad > 0:
            s = (' ' * pad) + s
        return s

    def update_loss(self, valid_epoch_result: EpochResult,
                    train_epoch_result: EpochResult,
                    ave_valid_loss: float) -> Tuple[float, bool]:
        """Update the training and validation loss.

        :return: a tuple of the latest minimum validation loss and whether or
                 not the last validation loss has decreased

        """
        logger = self.status_logger
        progress_logger = self.progress_logger
        optimizer = self.optimizer
        valid_loss = valid_epoch_result.ave_loss
        sfmt = self._fixed_sci_format
        # adjust the learning rate if a scheduler is configured
        if self.scheduler is not None:
            # the LambdaLR scheduler creates warnings when it gets the
            # validation loss
            if isinstance(self.scheduler, LambdaLR):
                self.scheduler.step()
            else:
                self.scheduler.step(ave_valid_loss)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('epoch ave valid loss/results averaged valid_loss ' +
                         f'{ave_valid_loss}/{valid_loss}, ' +
                         f'losses: {len(valid_epoch_result.losses)}')
        decreased = valid_loss < self.valid_loss_min
        dec_str = '\\/' if decreased else '/\\'
        if abs(ave_valid_loss - valid_loss) > 1e-10:
            logger.warning('validation loss and result are not close: ' +
                           f'{ave_valid_loss} - {valid_loss} > 1e-10')
        if train_epoch_result.contains_results:
            train_loss = train_epoch_result.ave_loss
        else:
            train_loss = -1
        msg = (f'tr:{sfmt(train_loss)}|' +
               f'va min:{sfmt(self.valid_loss_min)}|' +
               f'va:{sfmt(valid_loss)}')
        if self.scheduler is not None:
            lr = self._get_optimizer_lr(optimizer)
            msg += f'|lr:{sfmt(lr)}'
        msg += f' {dec_str}'
        if self.pbar is not None:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(msg)
            self.pbar.set_description(msg)
        else:
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'epoch {self.current_epoch}/' +
                            f'{self.n_epochs}: {msg}')
        # save model if validation loss has decreased
        if decreased:
            if progress_logger.isEnabledFor(logging.DEBUG):
                progress_logger.debug('validation loss decreased min/iter' +
                                      f'({self.valid_loss_min:.6f}' +
                                      f'/{valid_loss:.6f}); saving model')
            self.valid_loss_min = valid_loss
            self.consecutive_increased_count = 0
            self.validation_loss_decreases += 1
        else:
            if progress_logger.isEnabledFor(logging.DEBUG):
                progress_logger.debug('validation loss increased min/iter' +
                                      f'({self.valid_loss_min:.6f}' +
                                      f'/{valid_loss:.6f})')
            self.consecutive_increased_count += 1
        return self.valid_loss_min, decreased

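    # Editorial note (not in the original source): ``update_loss`` only
    # records the loss and reports whether it decreased; persisting the model
    # when the returned ``decreased`` flag is ``True`` is left to the caller
    # (presumably the :class:`.ModelExecutor`), which is why the debug
    # message above says "saving model".
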
    def _read_status(self) -> TrainStatus:
        """Read the early stop/update file and return a value to update the
        current epoch number (if any).

        """
        update = TrainStatus(UpdateAction.ITERATE_EPOCH)
        update_path = self.update_path
        if update_path is not None:
            if self.status_logger.isEnabledFor(logging.DEBUG):
                self.status_logger.debug(f'update check at {update_path}')
            if update_path.exists():
                data = None
                try:
                    with open(update_path) as f:
                        data = json.load(f)
                    if 'epoch' in data:
                        epoch = int(data['epoch'])
                        update.epoch = epoch
                        update.action = UpdateAction.SET_EPOCH
                        update.reason = (f'update from {update_path}: ' +
                                         f'setting epoch to: {epoch}')
                except Exception as e:
                    reason = f'bad format in {update_path}--assume exit: {e}'
                    update.action = UpdateAction.STOP
                    update.reason = reason
                update_path.unlink()
        return update

    def _get_stop_reason(self) -> str:
        """Return a human readable reason to stop training, or ``None`` if
        training should continue.

        """
        reason = None
        if self.current_epoch >= self.n_epochs:
            reason = f'epoch threshold reached at {self.n_epochs}'
        elif (self.consecutive_increased_count >
              self.max_consecutive_increased_count):
            reason = ('reached max consecutive increased count: ' +
                      f'{self.max_consecutive_increased_count}')
        return reason

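    # Example of the update file protocol (illustrative, not part of the
    # original source): writing ``{"epoch": 3}`` to :obj:`update_path` causes
    # :meth:`_read_status` to return ``SET_EPOCH`` with epoch 3, while an
    # empty or otherwise malformed file fails the JSON parse and yields
    # ``STOP`` (this is what :meth:`stop` relies on when it merely touches
    # the file).  The file is unlinked as soon as it has been read.
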
    def get_status(self) -> TrainStatus:
        """Return the epoch to set in the training loop of the
        :class:`.ModelExecutor`.

        """
        status = self._read_status()
        if status.action == UpdateAction.STOP:
            # setting to the max value fails the executor's train outer loop
            # causing a robust non-error exit
            status.epoch = sys.maxsize
        elif status.action == UpdateAction.SET_EPOCH:
            self.current_epoch = status.epoch
            if self.pbar is not None:
                self.pbar.reset()
                self.pbar.update(self.current_epoch)
        elif status.action == UpdateAction.ITERATE_EPOCH:
            self.current_epoch += 1
            status.epoch = self.current_epoch
            stop_reason = self._get_stop_reason()
            if self.pbar is not None:
                self.pbar.update()
            if stop_reason is not None:
                status.action = UpdateAction.STOP
                status.reason = stop_reason
        else:
            raise ModelError(f'Unknown status: {status}')
        if status.reason and self.status_logger.isEnabledFor(logging.INFO):
            self.status_logger.info(status.reason)
        return status

    def stop(self) -> bool:
        """Stop the training of the model.  Currently this is done by
        creating a file that the executor monitors.

        :return: ``True`` if the application is configured to early stop and
                 the signal has not already been given

        """
        update_path = self.update_path
        if update_path is not None and not update_path.is_file():
            update_path.parent.mkdir(parents=True, exist_ok=True)
            update_path.touch()
            self.status_logger.info(f'created early stop file: {update_path}')
            return True
        return False
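

# Illustrative driving loop (editorial sketch, not part of the original
# source): shows how a training loop might use ``TrainManager``.  The names
# ``optimizer``, ``scheduler``, ``run_epoch`` and ``save_model`` are
# hypothetical placeholders; in practice :class:`.ModelExecutor` plays this
# role.
#
#     import logging
#     log = logging.getLogger(__name__)
#     manager = TrainManager(
#         status_logger=log, progress_logger=log,
#         update_path=Path('target/update.json'),  # hypothetical location
#         max_consecutive_increased_count=5)
#     manager.start(optimizer, scheduler, n_epochs=20, pbar=None)
#     while True:
#         status = manager.get_status()
#         if status.action == UpdateAction.STOP:
#             break
#         # run_epoch returns (train EpochResult, validation EpochResult)
#         train_result, valid_result = run_epoch(status.epoch)
#         min_loss, decreased = manager.update_loss(
#             valid_result, train_result, valid_result.ave_loss)
#         if decreased:
#             save_model()  # persisting the model is the caller's job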