Source code for zensols.deeplearn.model.executor

"""This file contains the network model and data that holds the results.

"""
__author__ = 'Paul Landes'

from dataclasses import dataclass, field
from typing import List, Callable, Tuple, Dict, Any, Union, Optional
import sys
import gc
import logging
import itertools as it
from itertools import chain
from io import TextIOBase, StringIO
import random as rand
from pathlib import Path
import torch
from torch import nn
from tqdm import tqdm
from zensols.util import time
from zensols.config import Configurable, ConfigFactory, Writable, ClassResolver
from zensols.persist import (
    Deallocatable,
    persisted, PersistedWork, PersistableContainer,
    Stash, UnionStash,
)
from zensols.dataset import SplitStashContainer, DatasetSplitStash
from zensols.deeplearn import (
    ModelError, EarlyBailError,
    TorchConfig, DatasetSplitType, NetworkSettings
)
from zensols.deeplearn.result import (
    EpochResult, ModelResult, ModelSettings, ModelResultManager,
)
from zensols.deeplearn.batch import BatchStash, Batch
from . import (
    ModelResourceFactory, BaseNetworkModule,
    ModelManager, UpdateAction,
    BatchIterator, TrainManager,
)

# default message logger
logger = logging.getLogger(__name__ + '.status')
# logger for progress messages, which is active when the progress bar is not displayed
progress_logger = logging.getLogger(__name__ + '.progress')
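
# Illustrative sketch (not part of the original module): enabling these two
# loggers from application code.  The module path below follows this page's
# module name ('zensols.deeplearn.model.executor'); adjust it if the package
# layout differs.
def _example_enable_logging():
    logging.basicConfig(level=logging.WARNING)
    # status messages (model creation, training summaries, etc.)
    logging.getLogger('zensols.deeplearn.model.executor.status').setLevel(logging.INFO)
    # per-epoch progress messages, used when the progress bar is disabled
    logging.getLogger('zensols.deeplearn.model.executor.progress').setLevel(logging.INFO)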


@dataclass
class ModelExecutor(PersistableContainer, Deallocatable, Writable):
    """This class creates and uses a network to train, validate and test the
    model.  This class is either configured using a
    :class:`~zensols.config.factory.ConfigFactory` or is unpickled with
    :class:`.ModelManager`.  If the latter, it is restored from a previously
    trained (and possibly tested) state.

    Typically, after creating a nascent instance, :meth:`train` is called to
    train the model.  This returns the results, but the results are also
    available via the :class:`ResultManager` using the :obj:`model_manager`
    property.  To load previous results, use
    ``executor.result_manager.load()``.

    During training, the training set is used to train the weights of the
    model provided by the executor in the :obj:`model_settings`, then
    validated using the validation set.  When the validation loss is
    minimized, the following is saved to disk:

        * the settings: :obj:`net_settings`, :obj:`model_settings`,
        * the model weights,
        * the results of the training and validation thus far,
        * the entire configuration (which is later used to restore the
          executor),
        * random seed information, which includes the Python, Torch and GPU
          random state.

    After the model is trained, you can immediately test the model with
    :meth:`test`.  To be more certain of being able to reproduce the same
    results, it is recommended to load the model with
    ``model_manager.load_executor()``, which loads the last instance of the
    model that produced a minimum validation loss.

    :see: :class:`.ModelExecutor`

    :see: :class:`.NetworkSettings`

    :see: :class:`zensols.deeplearn.model.ModelSettings`

    """
    ATTR_EXP_META = ('model_settings',)

    config_factory: ConfigFactory = field()
    """The configuration factory that created this instance."""

    config: Configurable = field()
    """The configuration used in the configuration factory to create this
    instance.

    """
    name: str = field()
    """The name given in the configuration."""

    model_settings: ModelSettings = field()
    """The configuration of the model."""

    net_settings: NetworkSettings = field()
    """The settings used to configure the network."""

    dataset_stash: DatasetSplitStash = field()
    """The split data set stash that contains the ``BatchStash``, which
    contains the batches on which to train and test.

    """
    dataset_split_names: List[str] = field()
    """The list of split names in the ``dataset_stash`` in the order: train,
    validation, test (see :meth:`_get_dataset_splits`).

    """
    result_path: Path = field(default=None)
    """If not ``None``, a path to a directory where the results are to be
    dumped; the directory will be created if it doesn't exist when the
    results are generated.

    """
    update_path: Path = field(default=None)
    """The path to check for commands/updates to make while training.  If
    this is set and the file exists, then it is parsed as a JSON file.  If
    the file cannot be parsed, or is zero size, etc., then the training is
    (early) stopped.  If the file can be parsed and there is a single
    ``epoch`` dict entry, then the current epoch is set to that value.

    """
    intermediate_results_path: Path = field(default=None)
    """If this is set, then save the model and results to this path after
    validation for each training epoch.

    """
    progress_bar: bool = field(default=False)
    """Create a text/ASCII based progress bar if ``True``."""

    progress_bar_cols: int = field(default=None)
    """The number of console columns to use for the text/ASCII based progress
    bar.

    """
    def __post_init__(self):
        super().__init__()
        if not isinstance(self.dataset_stash, DatasetSplitStash) and False:
            raise ModelError('Expecting type DatasetSplitStash but ' +
                             f'got {self.dataset_stash.__class__}')
        self._model: BaseNetworkModule = None
        self._dealloc_model: bool = False
        self.model_result: ModelResult = None
        self.batch_stash.delegate_attr: bool = True
        self._criterion_optimizer_scheduler = PersistedWork(
            '_criterion_optimizer_scheduler', self)
        self._result_manager = PersistedWork('_result_manager', self)
        self._train_manager = PersistedWork('_train_manager', self)
        self.cached_batches: Dict[str, List[List[Batch]]] = {}
        self.debug: bool = False
        self._training_production: bool = None
        # set by ModelManager when available
        self._model_result_report: str = None

    @property
    def batch_stash(self) -> DatasetSplitStash:
        """The stash used to obtain the data for training and testing.  This
        stash should have a training, validation and test split.  The names
        of these splits are given in the ``dataset_split_names``.

        """
        return self.dataset_stash.split_container

    @property
    def feature_stash(self) -> Stash:
        """The stash used to generate the features, which is not to be
        confused with the batch source stash ``batch_stash``.

        """
        return self.batch_stash.split_stash_container

    @property
    def torch_config(self) -> TorchConfig:
        """Return the PyTorch configuration used to convert models and data
        (usually GPU) during training and test.

        """
        return self.batch_stash.model_torch_config

    @property
    @persisted('_result_manager')
    def result_manager(self) -> ModelResultManager:
        """Return the manager used for controlling the life cycle of the
        results generated by this executor.

        """
        if self.result_path is not None:
            return self._create_result_manager(self.result_path)

    def _create_result_manager(self, path: Path) -> ModelResultManager:
        return ModelResultManager(
            name=self.model_settings.model_name, path=path,
            model_path=self.model_settings.path)

    @property
    def model_result_report(self) -> Optional[str]:
        """A human readable summary of the model's performance.

        :return: the report if the model has been trained

        """
        if self.model_result is not None:
            sio = StringIO()
            self.model_result.write(writer=sio)
            return sio.getvalue().strip()
        else:
            return self._model_result_report

    @property
    @persisted('_model_manager')
    def model_manager(self) -> ModelManager:
        """Return the manager used for controlling the training of the model.

        """
        model_path = self.model_settings.path
        return ModelManager(model_path, self.config_factory, self.name)

    @property
    @persisted('_batch_iterator')
    def batch_iterator(self) -> BatchIterator:
        """The batch iterator used to iterate over batches during training,
        validation and testing.

        """
        resolver: ClassResolver = self.config_factory.class_resolver
        batch_iter_class_name = self.model_settings.batch_iteration_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'batch_iteration: {batch_iter_class_name}')
        batch_iter_class = resolver.find_class(batch_iter_class_name)
        batch_iter = batch_iter_class(self, logger)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'batch_iter={batch_iter}')
        return batch_iter

    @property
    def debug(self) -> Union[bool, int]:
        return self._debug

    @debug.setter
    def debug(self, debug: Union[bool, int]):
        self._debug = debug
        self.batch_iterator.debug = debug

    @property
    @persisted('_train_manager')
    def train_manager(self) -> TrainManager:
        """Return the train manager that assists with the training process."""
        return TrainManager(
            logger, progress_logger, self.update_path,
            self.model_settings.max_consecutive_increased_count)

    def _weight_reset(self, m):
        if hasattr(m, 'reset_parameters') and callable(m.reset_parameters):
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'resetting parameters on {m}')
            m.reset_parameters()
    def reset(self):
        """Reset the executor to its nascent state."""
        if logger.isEnabledFor(logging.INFO):
            logger.info('resetting executor')
        self._criterion_optimizer_scheduler.clear()
        self._deallocate_model()
    def load(self) -> nn.Module:
        """Clear all results and trained state and reload the last trained
        model from the file system.

        :return: the model that was loaded and registered in this instance of
                 the executor

        """
        if logger.isEnabledFor(logging.INFO):
            logger.info('reloading model weights')
        self._deallocate_model()
        self.model_manager._load_model_optim_weights(self)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'copied model to {self.model.device}')
        return self.model
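
    # Illustrative sketch (not part of the original module): restoring the
    # last saved (minimum validation loss) weights into an already configured
    # executor before testing, per the recommendation in the class docstring.
    def _example_reload_and_test(executor: 'ModelExecutor'):
        executor.load()      # reload the best weights from the file system
        res = executor.test()
        res.write()          # print a summary of the test results
        return res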
    def deallocate(self):
        super().deallocate()
        self._deallocate_model()
        self.deallocate_batches()
        self._try_deallocate(self.dataset_stash)
        self._deallocate_settings()
        self._criterion_optimizer_scheduler.deallocate()
        self._result_manager.deallocate()
        self.model_result = None
    def _deallocate_model(self):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('dealloc model: model exists/dealloc: ' +
                         f'{self._model is not None}/{self._dealloc_model}')
        if self._model is not None and self._dealloc_model:
            self._try_deallocate(self._model)
        self._model = None

    def _deallocate_settings(self):
        self.model_settings.deallocate()
        self.net_settings.deallocate()
    def deallocate_batches(self):
        set_of_ds_sets = self.cached_batches.values()
        ds_sets = chain.from_iterable(set_of_ds_sets)
        batches = chain.from_iterable(ds_sets)
        for batch in batches:
            batch.deallocate()
        self.cached_batches.clear()
    @property
    def model_exists(self) -> bool:
        """Return whether the executor has a model.

        :return: ``True`` if the model has been trained or loaded

        """
        return self._model is not None

    @property
    def model(self) -> BaseNetworkModule:
        """Get the PyTorch module that is used for training and test."""
        if self._model is None:
            raise ModelError('No model is populated; use \'load\'')
        return self._model

    @model.setter
    def model(self, model: BaseNetworkModule):
        """Set the PyTorch module that is used for training and test."""
        self._set_model(model, False, True)

    def _set_model(self, model: BaseNetworkModule,
                   take_owner: bool, deallocate: bool):
        if logger.isEnabledFor(level=logging.DEBUG):
            logger.debug(f'setting model: {type(model)}')
        if deallocate:
            self._deallocate_model()
        self._model = model
        self._dealloc_model = take_owner
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'setting dealloc model: {self._dealloc_model}')
        self._criterion_optimizer_scheduler.clear()

    def _get_or_create_model(self) -> BaseNetworkModule:
        if self._model is None:
            self._dealloc_model = True
            model = self._create_model()
            self._model = model
        else:
            model = self._model
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'created model as dealloc: {self._dealloc_model}')
        return model

    def _create_model(self) -> BaseNetworkModule:
        """Create the network model instance."""
        mng: ModelManager = self.model_manager
        model = mng._create_module(self.net_settings, self.debug)
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'created model on {model.device} ' +
                        f'with {self.torch_config}')
        return model

    def _create_model_result(self) -> ModelResult:
        res = ModelResult(
            self.config,
            f'{self.model_settings.model_name}: {ModelResult.get_num_runs()}',
            self.model_settings, self.net_settings,
            self.batch_stash.decoded_attributes)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'creating model result ({id(res)}): ' +
                         self.model_settings.model_name)
        return res

    @property
    @persisted('_criterion_optimizer_scheduler')
    def criterion_optimizer_scheduler(self) -> \
            Tuple[nn.Module, torch.optim.Optimizer, Any]:
        """Return the loss function, the gradient descent optimizer and the
        learning rate scheduler.

        """
        criterion = self._create_criterion()
        optimizer, scheduler = self._create_optimizer_scheduler()
        return criterion, optimizer, scheduler

    def _create_criterion(self) -> nn.Module:
        """Factory method to create the loss function (criterion)."""
        resolver = self.config_factory.class_resolver
        criterion_class_name = self.model_settings.criterion_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion: {criterion_class_name}')
        criterion_class = resolver.find_class(criterion_class_name)
        criterion = criterion_class()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'criterion={criterion}')
        return criterion

    def _create_optimizer_scheduler(self) -> Tuple[torch.optim.Optimizer, Any]:
        """Factory method to create the optimizer and the learning rate
        scheduler (if any).

        """
        model = self.model
        resolver = self.config_factory.class_resolver
        optimizer_class_name = self.model_settings.optimizer_class_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'optimizer: {optimizer_class_name}')
        optimizer_class = resolver.find_class(optimizer_class_name)
        if self.model_settings.optimizer_params is None:
            optimizer_params = {}
        else:
            optimizer_params = dict(self.model_settings.optimizer_params)
        optimizer_params['lr'] = self.model_settings.learning_rate
        if issubclass(optimizer_class, ModelResourceFactory):
            # model resource factories are callable
            opt_call = optimizer_class()
            optimizer_params['model'] = model
            optimizer_params['executor'] = self
        else:
            opt_call = optimizer_class
        optimizer = opt_call(model.parameters(), **optimizer_params)
        scheduler_class_name = self.model_settings.scheduler_class_name
        if scheduler_class_name is not None:
            scheduler_class = resolver.find_class(scheduler_class_name)
            scheduler_params = self.model_settings.scheduler_params
            if scheduler_params is None:
                scheduler_params = {}
            else:
                scheduler_params = dict(scheduler_params)
            scheduler_params['optimizer'] = optimizer
            if issubclass(scheduler_class, ModelResourceFactory):
                # model resource factories are callable
                sch_call = scheduler_class()
                scheduler_params['executor'] = self
            else:
                sch_call = scheduler_class
            scheduler = sch_call(**scheduler_params)
        else:
            scheduler = None
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'optimizer={optimizer}')
        return optimizer, scheduler
    def get_model_parameter(self, name: str):
        """Return a parameter of the model, found in ``model_settings``."""
        return getattr(self.model_settings, name)
    def set_model_parameter(self, name: str, value: Any):
        """Safely set a parameter of the model, found in ``model_settings``.
        This makes the corresponding update in the configuration, so that
        when it is restored (i.e. for test) the parameters are consistent
        with the trained model.  The value is converted to a string as the
        configuration representation stores all data values as strings.

        *Important*: ``eval`` syntaxes are not supported, and probably not
        the kind of values you want to set with this interface anyway.

        :param name: the name of the value to set, which is the key in the
                     configuration file

        :param value: the value to set on the model and the configuration

        """
        self.config.set_option(
            name, str(value), section=self.model_settings.name)
        setattr(self.model_settings, name, value)
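
    # Illustrative sketch (not part of the original module): updating a model
    # hyperparameter so the change is reflected both in the in-memory
    # ``model_settings`` and in the configuration used to restore the
    # executor.  ``epochs`` is the setting read during training.
    def _example_set_epochs(executor: 'ModelExecutor'):
        executor.set_model_parameter('epochs', 5)
        assert executor.get_model_parameter('epochs') == 5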
    def get_network_parameter(self, name: str):
        """Return a parameter of the network, found in ``network_settings``."""
        return getattr(self.net_settings, name)
    def set_network_parameter(self, name: str, value: Any):
        """Safely set a parameter of the network, found in
        ``network_settings``.  This makes the corresponding update in the
        configuration, so that when it is restored (i.e. for test) the
        parameters are consistent with the trained network.  The value is
        converted to a string as the configuration representation stores all
        data values as strings.

        *Important*: ``eval`` syntaxes are not supported, and probably not
        the kind of values you want to set with this interface anyway.

        :param name: the name of the value to set, which is the key in the
                     configuration file

        :param value: the value to set on the network and the configuration

        """
        self.config.set_option(
            name, str(value), section=self.net_settings.name)
        setattr(self.net_settings, name, value)
    def _to_iter(self, ds):
        ds_iter = ds
        if isinstance(ds_iter, Stash):
            ds_iter = ds_iter.values()
        return ds_iter

    def _gc(self, level: int):
        """Invoke the Python garbage collector if ``level`` is high enough.
        The *lower* the value of ``level``, the more often it will be run
        during training, testing and validation.

        :param level: the priority of the need to collect; the lower the
                      value, the more often collection is needed

        """
        if level <= self.model_settings.gc_level:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('garbage collecting')
            self._notify('gc_start')
            with time('garbage collected', logging.DEBUG):
                gc.collect()
            self._notify('gc_end')

    def _notify(self, event: str, context: Any = None):
        """Notify observers of events from this class."""
        self.model_settings.observer_manager.notify(event, self, context)

    def _should_store_result(self) -> bool:
        mode: str = self.model_settings.store_model_result
        return ((mode == 'always') or
                (mode == 'train' and not self._training_production))

    def _train(self, train: List[Batch], valid: List[Batch]):
        """Train the network model and record validation and training losses.
        Every time the validation loss shrinks, the model is saved to disk.

        """
        store_mode: str = self.model_settings.store_model_result
        store_result: bool = \
            ((store_mode == 'always') or
             (store_mode == 'test' and not self._training_production))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'store results: {store_result} ' +
                         f'based on mode: {store_mode}')
        n_epochs: int = self.model_settings.epochs
        # create network model, loss and optimization functions
        model: BaseNetworkModule = self._get_or_create_model()
        model = self.torch_config.to(model)
        self._model = model
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'training model {type(model)} on {model.device} ' +
                        f'for {n_epochs} epochs using ' +
                        f'learning rate {self.model_settings.learning_rate}')
        criterion, optimizer, scheduler = self.criterion_optimizer_scheduler
        # create a second module manager for after epoch results
        if self.intermediate_results_path is not None:
            model_path = self.intermediate_results_path
            intermediate_manager = self._create_result_manager(model_path)
            intermediate_manager.file_pattern = '{prefix}.{ext}'
        else:
            intermediate_manager = None
        train_manager = self.train_manager
        action = UpdateAction.ITERATE_EPOCH
        # set up graphical progress bar
        exec_logger = logging.getLogger(__name__)
        if self.progress_bar and \
           (exec_logger.level == 0 or exec_logger.level > logging.INFO) and \
           (progress_logger.level == 0 or
                progress_logger.level > logging.INFO):
            pbar = tqdm(total=n_epochs, ncols=self.progress_bar_cols)
        else:
            pbar = None
        train_manager.start(optimizer, scheduler, n_epochs, pbar)
        self.model_result.train.start()
        self.model_result.validation.start()
        # epochs loop
        while action != UpdateAction.STOP:
            epoch: int = train_manager.current_epoch
            train_epoch_result = EpochResult(epoch, DatasetSplitType.train)
            valid_epoch_result = EpochResult(
                epoch, DatasetSplitType.validation)
            if progress_logger.isEnabledFor(logging.INFO):
                progress_logger.info(f'training on epoch: {epoch}')
            self.model_result.train.append(train_epoch_result)
            self.model_result.validation.append(valid_epoch_result)

            # train ----
            # prep model for training and train
            model.train()
            train_epoch_result.start()
            self._notify('train_start', {'epoch': epoch})
            for batch in self._to_iter(train):
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'training on batch: {batch.id}')
                with time('trained batch', level=logging.DEBUG):
                    self.batch_iterator.iterate(
                        model, optimizer, criterion, batch,
                        train_epoch_result, DatasetSplitType.train)
                self._gc(3)
            self._notify('train_end', {'epoch': epoch})
            train_epoch_result.end()
            self._gc(2)

            # validate ----
            # prep model for evaluation and evaluate
            ave_valid_loss = 0
            model.eval()
            valid_epoch_result.start()
            self._notify('validation_start', {'epoch': epoch})
            for batch in self._to_iter(valid):
                # forward pass: compute predicted outputs by passing inputs
                # to the model
                with torch.no_grad():
                    loss = self.batch_iterator.iterate(
                        model, optimizer, criterion, batch,
                        valid_epoch_result, DatasetSplitType.validation)
                    ave_valid_loss += (loss.item() * batch.size())
                self._gc(3)
            self._notify('validation_end', {'epoch': epoch})
            valid_epoch_result.end()
            ave_valid_loss = ave_valid_loss / len(valid)
            self._gc(2)

            valid_loss_min, decreased = train_manager.update_loss(
                valid_epoch_result, train_epoch_result, ave_valid_loss)
            if decreased:
                self.model_manager._save_executor(self, store_result)
                if intermediate_manager is not None:
                    inter_res = self.model_result.get_intermediate()
                    intermediate_manager.save_text_result(inter_res)
                    intermediate_manager.save_plot_result(inter_res)
            # look for indication of update or early stopping
            status = train_manager.get_status()
            action = status.action

        val_losses = train_manager.validation_loss_decreases
        if logger.isEnabledFor(logging.INFO):
            logger.info('final minimum validation ' +
                        f'loss: {train_manager.valid_loss_min}, ' +
                        f'{val_losses} decreases')
        if val_losses == 0:
            logger.warning('no validation loss decreases encountered, ' +
                           'so there was no model saved; model can not ' +
                           'be tested')
        self.model_result.train.end()
        self.model_result.validation.end()
        if store_result:
            # add updated training and validation results (not weights)
            self.model_manager._save_final_trained_results(self)

    def _test(self, batches: List[Batch]):
        """Test the model on the test set.  If a model is not given, it is
        loaded (unpersisted) from the file system.

        """
        # create the loss and optimization functions
        criterion, optimizer, scheduler = self.criterion_optimizer_scheduler
        model = self.torch_config.to(self.model)
        # track epoch progress
        test_epoch_result = EpochResult(0, DatasetSplitType.test)
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'testing model {type(model)} on {model.device}')
        # if for some reason the model was trained but not tested, we'll load
        # from the model file, which will have no train results (bad idea)
        if self.model_result is None:
            self.model_result = self._create_model_result()
        self.model_result.reset(DatasetSplitType.test)
        self.model_result.test.start()
        self.model_result.test.append(test_epoch_result)
        # prep model for evaluation
        model.eval()
        # run the model on test data
        test_epoch_result.start()
        for batch in self._to_iter(batches):
            # forward pass: compute predicted outputs by passing inputs
            # to the model
            with torch.no_grad():
                self.batch_iterator.iterate(
                    model, optimizer, criterion, batch,
                    test_epoch_result, DatasetSplitType.test)
            self._gc(3)
        test_epoch_result.end()
        self._gc(2)
        self.model_result.test.end()

    def _preproces_training(self, ds_train: Tuple[Batch]):
        """Preprocess the training set, which for this implementation
        includes a shuffle if configured in the model settings.

        """
        self._notify('preprocess_training_start')
        if self.model_settings.shuffle_training:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('shuffling training dataset')
            # data sets are ordered with training as the first
            rand.shuffle(ds_train)
        self._notify('preprocess_training_end')

    def _calc_batch_limit(self, src: Stash,
                          batch_limit: Union[int, float]) -> int:
        if batch_limit <= 0:
            raise ModelError(f'Batch limit must be positive: {batch_limit}')
        if isinstance(batch_limit, float):
            if batch_limit > 1.0:
                raise ModelError('Batch limit must be less than 1 ' +
                                 f'when a float: {batch_limit}')
            vlim = round(len(src) * batch_limit)
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('batch limit calculated as a percentage: ' +
                             f'{vlim} = {len(src)} * {batch_limit}')
        else:
            vlim = batch_limit
        if isinstance(src, SplitStashContainer):
            desc = f' for {src.split_name}'
        else:
            desc = ''
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'using batch limit: {vlim}{desc}')
        return vlim

    def _prepare_datasets(self, batch_limit: Union[int, float],
                          to_deallocate: List[Batch],
                          ds_src: List[Stash]) -> \
            Tuple[int, List[List[Batch]]]:
        """Return batches for each data set.  The batches are returned per
        dataset as given in :meth:`_get_dataset_splits`.

        :return: the count of batches loaded and the batches per dataset in
                 the form ``[(training batch 1..N), (validation batch 1..N),
                 (test batch 1..N)]``

        """
        biter = self.model_settings.batch_iteration
        cnt = 0
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'preparing datasets using iteration: {biter}')
        self._notify('prepare_datasets_start', biter)
        if biter == 'gpu':
            ds_dst = []
            for src in ds_src:
                vlim = self._calc_batch_limit(src, batch_limit)
                cpu_batches = tuple(it.islice(src.values(), vlim))
                gpu_batches = list(map(lambda b: b.to(), cpu_batches))
                cnt += len(gpu_batches)
                # the `to` call returns the same instance if the tensor is
                # already on the GPU, so only deallocate batches copied over
                for cpu_batch, gpu_batch in zip(cpu_batches, gpu_batches):
                    if cpu_batch is not gpu_batch:
                        to_deallocate.append(cpu_batch)
                if not self.model_settings.cache_batches:
                    to_deallocate.extend(gpu_batches)
                ds_dst.append(gpu_batches)
        elif biter == 'cpu':
            ds_dst = []
            for src in ds_src:
                vlim = self._calc_batch_limit(src, batch_limit)
                batches = list(it.islice(src.values(), vlim))
                cnt += len(batches)
                if not self.model_settings.cache_batches:
                    to_deallocate.extend(batches)
                ds_dst.append(batches)
        elif biter == 'buffered':
            ds_dst = ds_src
            cnt = '?'
        else:
            raise ModelError(f'No such batch iteration method: {biter}')
        self._notify('prepare_datasets_end', biter)
        self._preproces_training(ds_dst[0])
        return cnt, ds_dst

    def _execute(self, sets_name: str, description: str,
                 func: Callable, ds_src: tuple) -> bool:
        """Either train or test the model based on method ``func``.

        :param sets_name: the name of the data sets, which is either
                          ``train`` or ``test``

        :param description: a description added to the results name, if given

        :param func: the method to call to do the training or testing

        :param ds_src: a tuple of datasets in a form such as ``(train,
                       validation, test)`` (see :meth:`_get_dataset_splits`)

        :return: ``True`` if training/testing was successful, ``False`` if an
                 exception occurred or an early bail was requested

        """
        to_deallocate: List[Batch] = []
        ds_dst: List[List[Batch]] = None
        batch_limit = self.model_settings.batch_limit
        biter = self.model_settings.batch_iteration
        if self.model_settings.cache_batches and biter == 'buffered':
            raise ModelError('Can not cache batches for batch ' +
                             'iteration setting \'buffered\'')
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'batch iteration: {biter}, limit: {batch_limit}' +
                        f', caching: {self.model_settings.cache_batches}'
                        f', cached: {len(self.cached_batches)}')
        self._notify('execute_start', sets_name)
        self._gc(1)
        ds_dst = self.cached_batches.get(sets_name)
        if ds_dst is None:
            cnt = 0
            with time('loaded {cnt} batches'):
                cnt, ds_dst = self._prepare_datasets(
                    batch_limit, to_deallocate, ds_src)
            if self.model_settings.cache_batches:
                self.cached_batches[sets_name] = ds_dst
        if logger.isEnabledFor(logging.INFO):
            logger.info('train/validation sets: ' +
                        f'{" ".join(map(lambda l: str(len(l)), ds_dst))}')
        try:
            with time(f'executed {sets_name}'):
                func(*ds_dst)
            if description is not None:
                res_name = f'{self.model_result.index}: {description}'
                self.model_result.name = res_name
            return True
        except EarlyBailError as e:
            logger.warning(f'<{e}>')
            self.reset()
            return False
        finally:
            self._notify('execute_end', sets_name)
            self._train_manager.clear()
            if logger.isEnabledFor(logging.INFO):
                logger.info(f'deallocating {len(to_deallocate)} batches')
            for batch in to_deallocate:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'deallocating: {batch}')
                batch.deallocate()
            self._gc(1)
            self.torch_config.empty_cache()

    def _get_dataset_splits(self) -> Tuple[BatchStash]:
        """Return a stash, one for each respective data set tracked by this
        executor.

        """
        def map_split(n: str) -> DatasetSplitStash:
            s = splits.get(n)
            if s is None:
                raise ModelError(
                    f"No split '{n}' in {self.dataset_stash.split_names}, " +
                    f'executor splits: {self.dataset_split_names}')
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'split: {n}={len(s)}')
            return s

        splits = self.dataset_stash.splits
        return tuple(map(map_split, self.dataset_split_names))
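
    # Illustrative sketch (not part of the original module): the batch limit
    # semantics implemented by ``_calc_batch_limit`` above.  An ``int`` limit
    # caps the number of batches directly; a ``float`` in (0, 1] is treated
    # as a fraction of the split size.  The numbers are hypothetical.
    def _example_batch_limit(split_size: int = 200) -> int:
        batch_limit = 0.1                        # use 10% of the split
        vlim = round(split_size * batch_limit)   # mirrors the calculation above
        assert vlim == 20
        return vlim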
    def train(self, description: str = None) -> ModelResult:
        """Train the model."""
        self.model_result = self._create_model_result()
        train, valid, _ = self._get_dataset_splits()
        self._training_production = False
        self._execute('train', description, self._train, (train, valid))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'trained model result: {self.model_result}')
        return self.model_result
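
    # Illustrative sketch (not part of the original module): a typical
    # train-then-test cycle on a configured executor.  The description string
    # is hypothetical; results are also saved automatically by the model
    # manager whenever the validation loss decreases during training.
    def _example_train_then_test(executor: 'ModelExecutor'):
        train_res = executor.train('baseline run')
        test_res = executor.test()
        test_res.write()     # print a human readable performance summary
        return train_res, test_res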
    def test(self, description: str = None) -> ModelResult:
        """Test the model."""
        train, valid, test = self._get_dataset_splits()
        if self.model_result is None:
            logger.warning('no results found--loading')
            self.model_result = self.result_manager.load()
        self._training_production = False
        self._execute('test', description, self._test, (test,))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'tested model result: {self.model_result}')
        return self.model_result
    def train_production(self, description: str = None) -> ModelResult:
        """Train and test the model on the training and test datasets.  This
        is used for a "production" model that is used for some purpose other
        than evaluation.

        """
        self.model_result = self._create_model_result()
        train, valid, test = self._get_dataset_splits()
        train = UnionStash((train, test))
        self._training_production = True
        self._execute('train production', description,
                      self._train, (train, valid))
        return self.model_result
    def predict(self, batches: List[Batch]) -> ModelResult:
        """Create predictions on ad-hoc data.

        :param batches: contains the data (X) on which to predict

        :return: the results of the predictions

        """
        for batch in batches:
            self.batch_stash.populate_batch_feature_mapping(batch)
        self._test(batches)
        return self.model_result.test
    def write_model(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        model = self._get_or_create_model()
        sio = StringIO()
        sp = self._sp(depth + 1)
        nl = '\n'
        print(model, file=sio)
        self._write_line('model:', depth, writer)
        writer.write(nl.join(
            map(lambda s: sp + s, sio.getvalue().split(nl))))
    def write_settings(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        self._write_line('network settings:', depth, writer)
        self._write_dict(self.net_settings.asdict(), depth + 1, writer)
        self._write_line('model settings:', depth, writer)
        self._write_dict(self.model_settings.asdict(), depth + 1, writer)
    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout,
              include_settings: bool = False, include_model: bool = False):
        sp = self._sp(depth)
        writer.write(f'{sp}model: {self.model_settings.model_name}\n')
        writer.write(f'{sp}feature splits:\n')
        self.feature_stash.write(depth + 1, writer)
        writer.write(f'{sp}batch splits:\n')
        self.dataset_stash.write(depth + 1, writer)
        if include_settings:
            self.write_settings(depth, writer)
        if include_model:
            self.write_model(depth, writer)
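
# Illustrative sketch (not part of the original module): printing a summary of
# a configured executor, optionally including the network/model settings and
# the PyTorch module itself.
def _example_describe(executor: ModelExecutor):
    executor.write(include_settings=True, include_model=True)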