"""A utility class to help with a ``ModelExecutor`` life cycle.
"""
from __future__ import annotations
__author__ = 'Paul Landes'
from typing import Any, Dict, Tuple
from dataclasses import dataclass, field
import logging
from io import StringIO
from pathlib import Path
import torch
from zensols.util import time
from zensols.config import ConfigFactory, Configurable
from .. import ModelError, TorchConfig, NetworkSettings
from . import BaseNetworkModule
logger = logging.getLogger(__name__)
@dataclass
class ModelManager(object):
"""This class manages the lifecycle of an instance of a ``ModelExecutor``.
This class is mostly used by the executor to control its lifecycle.
However, a client can also use an instance of this to revive a model that's
been saved to disk with the ``ModelResultManager``.
:see: :class:`.ModelExecutor`
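A minimal sketch of reviving a model saved to disk (the path is
illustrative)::

    mng = ModelManager.load_from_path(Path('path/to/model'))
    executor = mng.load_executor()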
"""
path: Path = field()
"""The path of where the model results saved to disk by
:class:`.zensols.deeplearn.results.ModelResultManager`.
"""
config_factory: ConfigFactory = field()
"""The configuration factory to be used to create the ``ModelExecutor``."""
model_executor_name: str = field(default=None)
"""The configuration entry and name of the ``ModelExecutor`` instance."""
persist_random_seed_context: bool = field(default=True)
"""If ``True`` persist the current random seed state, which helps in
creating consistent results across train/test/validate.
"""
keep_last_state_dict: bool = field(default=False)
"""Whether to store the PyTorch module state in attribute
``last_saved_state_dict``.
"""
@staticmethod
def _get_paths(path: Path) -> Tuple[Path, Path]:
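# the model is persisted in two files: a state file (config factory, model
# results, random seed) and a weight file holding the PyTorch tensors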
return (path / 'state.pt', path / 'weight.pt')
@classmethod
def load_from_path(cls, path: Path) -> ModelManager:
"""Load and return an instance of this class from a previously saved
model. This method exists to recreate a :class:`.ModelManager` from a
saved file from scratch. The returned model manager can be used to
create the executor or :class:`.ModelFacade` using
:obj:`config_factory`.
:param path: points to the model file persisted with
:py:meth:`_save_executor`
:return: an instance of :class:`.ModelManager` that was used to save
the executor pointed to by ``path``
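A usage sketch; the path and the ``facade`` section name are
illustrative::

    mng = ModelManager.load_from_path(Path('path/to/model'))
    facade = mng.config_factory.instance('facade')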
"""
checkpoint = cls._load_checkpoint(*cls._get_paths(path))
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'keys: {checkpoint.keys()}')
config_factory = checkpoint['config_factory']
model_executor_name = checkpoint['model_executor_name']
persist_random = checkpoint['random_seed_context'] is not None
return cls(path, config_factory, model_executor_name, persist_random)
def load_executor(self, config_overwrites: Configurable = None) -> \
'ModelExecutor':
"""Load the model the last saved model from the disk. This is used load
an instance of a ``ModelExecutor`` with all previous state completely in
tact. It does this by using an instance of
:class:`zensols.config.factory.Configurable` and a
:class:`zensols.config.factory.ImportConfigFactory` to reconstruct the
executor and its state by recreating all instances.
After the executor has been recreated with the factory, the previous
model results and model weights are restored.
:param config_overwrites: a configuration whose sections are copied over
the application configuration to overwrite the model
configuration before the executor is instantiated
:return: an instance of :class:`.ModelExecutor`
:see: :class:`zensols.deeplearn.model.ModelExecutor`
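A sketch of overriding configuration before the executor is created,
given a manager instance ``mng`` and assuming ``IniConfig`` from the
``zensols.config`` package (the file name is illustrative)::

    overrides = IniConfig('model-overrides.conf')
    executor = mng.load_executor(config_overwrites=overrides)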
"""
checkpoint: Dict[str, Any] = self._get_checkpoint(True)
# reload the config factory even if loaded from `load_from_path` since,
# in that case, this instance will be deallocated in the facade
config_factory: ConfigFactory = checkpoint['config_factory']
self._set_random_seed(checkpoint)
# overwrite model configuration before the executor is instantiated
if config_overwrites is not None:
config_overwrites.copy_sections(config_factory.config)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'loading config factory: {config_factory}')
# create the executor from the executor section
executor: 'Executor' = config_factory.instance(
checkpoint['model_executor_name'])
# create the PyTorch model
model: BaseNetworkModule = self._create_module(executor.net_settings)
# load and set the state
self._load_optimizer_state(executor, model, checkpoint)
if 'model_result_report' in checkpoint:
executor._model_result_report = checkpoint['model_result_report']
return executor
def _load_optimizer_state(self, executor: Any, model: BaseNetworkModule,
checkpoint: Dict[str, Any]):
model.load_state_dict(checkpoint['model_state_dict'])
executor._set_model(model, True, False)
executor.model_result = checkpoint['model_result']
criterion, optimizer, scheduler = executor.criterion_optimizer_scheduler
if 'model_scheduler_state_dict' in checkpoint:
scheduler_state = checkpoint['model_scheduler_state_dict']
else:
scheduler_state = None
optimizer.load_state_dict(checkpoint['model_optim_state_dict'])
if scheduler is not None and scheduler_state is not None:
scheduler.load_state_dict(scheduler_state)
if logger.isEnabledFor(logging.INFO):
logger.info(f'loaded model from {executor.model_settings.path} ' +
f'on device {model.device}')
def _load_model_optim_weights(self, executor):
"""Load the model and optimizer weights from the last check point. A
side effect is that the optimizer is recreated.
"""
model = executor._get_or_create_model()
checkpoint = self._get_checkpoint(True)
self._load_optimizer_state(executor, model, checkpoint)
def _save_executor(self, executor: Any, save_model_result: bool):
"""Save a ``ModelExecutor`` instance.
:param executor: the executor to persist to disk
:param save_model_result: whether to include the executor's model results
in the checkpoint
"""
logger.debug('saving model state')
if self.persist_random_seed_context:
random_seed_context = TorchConfig.get_random_seed_context()
else:
random_seed_context = None
criterion, optimizer, scheduler = executor.criterion_optimizer_scheduler
if scheduler is None:
scheduler_state = None
else:
scheduler_state = scheduler.state_dict()
state_dict = executor.model.state_dict()
if self.keep_last_state_dict:
self.last_saved_state_dict = self._copy_state_dict(state_dict)
if save_model_result:
model_result = executor.model_result
else:
model_result = None
checkpoint = {'config_factory': self.config_factory,
'random_seed_context': random_seed_context,
'model_executor_name': self.model_executor_name,
'net_settings_name': executor.net_settings.name,
'model_result': model_result,
'model_optim_state_dict': optimizer.state_dict(),
'model_scheduler_state_dict': scheduler_state,
'model_state_dict': state_dict}
if model_result is None and executor.model_settings.store_report:
sio = StringIO()
executor.model_result.write(writer=sio)
checkpoint['model_result_report'] = sio.getvalue().strip()
self._save_checkpoint(checkpoint, True)
def _create_module(self, net_settings: NetworkSettings,
reload: bool = False) -> BaseNetworkModule:
"""Create a new instance of the network model.
"""
resolver = self.config_factory.class_resolver
initial_reload = resolver.reload
try:
resolver.reload = reload
return net_settings.create_module()
finally:
resolver.reload = initial_reload
@staticmethod
def _copy_state_dict(state_dict):
"""Copy the PyTorch module state (weights) and return them as a dict.
"""
return {k: state_dict[k].clone() for k in state_dict.keys()}
def _set_random_seed(self, checkpoint: Dict[str, Any]):
random_seed_context = checkpoint['random_seed_context']
if random_seed_context is not None:
TorchConfig.set_random_seed(**random_seed_context)
def _save_final_trained_results(self, executor):
"""Save the results of the :class:`.ModelResult`, which is typically
called when the validation loss decreases. Note this does not save the
model weights, since doing so might clobber the best saved weights with
those of an overtrained model (assuming the last model to converge with
the lowest validation loss was already saved). Instead, it only saves the
model results.
:param executor: the executor with the model results to save
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'updating results: {self.path}')
checkpoint = self._get_checkpoint(False)
checkpoint['model_result'] = executor.model_result
self._save_checkpoint(checkpoint, False)
def _save_checkpoint(self, checkpoint: Dict[str, Any], save_weights: bool):
"""Save the check point to disk.
:param checkpoint: all model state (results, random seed, weights etc)
:param save_weights: if ``True`` then save the weights to the weight
file (in addition to the state to the state file)
"""
state_path, weight_path = self._get_paths(self.path)
weights = {}
for k in 'model_optim_state_dict model_state_dict'.split():
wval = checkpoint.pop(k, None)
if save_weights and wval is None:
raise ModelError(
f'Missing checkpoint key while saving weights: {k}')
weights[k] = wval
self.path.mkdir(parents=True, exist_ok=True)
if save_weights:
with time(f'saved model weights to {weight_path}'):
torch.save(weights, str(weight_path))
with time(f'saved model state to {state_path}'):
torch.save(checkpoint, str(state_path))
def _get_checkpoint(self, load_weights: bool) -> Dict[str, Any]:
"""The check point from loaded by the PyTorch framework. This contains
the executor, model results, and model weights.
:param load_weights: if ``True`` load the weights from the weights file
and add it to the checkpoint state
"""
state_path, weight_path = self._get_paths(self.path)
if not load_weights:
weight_path = None
return self._load_checkpoint(state_path, weight_path)
@staticmethod
def _load_checkpoint(state_path: Path, weight_path: Path) -> \
Dict[str, Any]:
if not state_path.is_file():
raise ModelError(f'No such state file: {state_path}')
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f'loading check point from: {state_path}')
with time(f'loaded check point from {state_path}'):
cp = torch.load(str(state_path))
if weight_path is not None:
params = {}
if not torch.cuda.is_available():
params['map_location'] = torch.device('cpu')
weights = torch.load(str(weight_path), **params)
cp.update(weights)
return cp