Source code for zensols.deeplearn.batch.meta

"""Contains container classes for batch data.

"""
__author__ = 'Paul Landes'

from typing import Dict, Type
from dataclasses import dataclass
from dataclasses import field as dc_field
import sys
from io import TextIOBase
from zensols.config import Dictable
from zensols.deeplearn import NetworkSettings
from zensols.deeplearn.vectorize import FeatureVectorizer
from . import DataPoint, Batch, BatchFeatureMapping, FieldFeatureMapping


@dataclass
class BatchFieldMetadata(Dictable):
    """Describes one field mapping of a batch object by pairing the mapping
    with the vectorizer used for it.

    """
    field: FieldFeatureMapping = dc_field()
    """The field mapping."""

    vectorizer: FeatureVectorizer = dc_field(repr=False)
    """The vectorizer used to map the field."""

    @property
    def shape(self):
        """The shape as reported by :obj:`vectorizer`."""
        vec = self.vectorizer
        return vec.shape

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the mapping's ``attr`` as a heading, followed by a labeled
        ``field:`` section and a labeled ``vectorizer:`` section.

        :param depth: the indentation level of the heading line

        :param writer: the sink the output is written to

        """
        label_depth = depth + 1
        body_depth = depth + 2
        self._write_line(self.field.attr, depth, writer)
        # each section is a label line followed by the component's own output
        sections = (('field:', self.field), ('vectorizer:', self.vectorizer))
        for label, component in sections:
            self._write_line(label, label_depth, writer)
            component.write(body_depth, writer)
@dataclass
class BatchMetadata(Dictable):
    """Describes metadata about a :class:`.Batch` instance.

    """
    data_point_class: Type[DataPoint] = dc_field()
    """The :class:`.DataPoint` class whose instances are created at encoding
    time.

    """
    batch_class: Type[Batch] = dc_field()
    """The :class:`.Batch` class whose instances are created at encoding time.

    """
    mapping: BatchFeatureMapping = dc_field()
    """The mapping used for encoding and decoding the batch."""

    fields_by_attribute: Dict[str, BatchFieldMetadata] = dc_field(repr=False)
    """The per-field metadata, keyed by attribute."""

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        """Write the data point class, the batch class, the mapping and then
        every entry of :obj:`fields_by_attribute`.

        :param depth: the indentation level of the top-level lines

        :param writer: the sink the output is written to

        """
        self._write_line(f'data point: {self.data_point_class}', depth, writer)
        self._write_line(f'batch: {self.batch_class}', depth, writer)
        self._write_line('mapping:', depth, writer)
        self.mapping.write(depth + 1, writer)
        self._write_line('attributes:', depth, writer)
        # only the metadata values are written; the keys are not needed here
        for field_meta in self.fields_by_attribute.values():
            field_meta.write(depth + 1, writer)
@dataclass
class MetadataNetworkSettings(NetworkSettings):
    """A network settings container that has metadata about batches it
    receives for its model.

    """
    _PERSITABLE_TRANSIENT_ATTRIBUTES = {'batch_stash'}

    batch_stash: 'BatchStash' = dc_field(repr=False)
    """The batch stash that created the batches and has the batch metadata.

    """
    @property
    def batch_metadata(self) -> BatchMetadata:
        """The batch metadata used by this model, taken from
        :obj:`batch_stash`.

        """
        stash = self.batch_stash
        return stash.batch_metadata

    def _from_dictable(self, *args, **kwargs):
        """Return the dictionary built by the super class with the
        ``batch_stash`` entry removed.

        """
        as_dict = super()._from_dictable(*args, **kwargs)
        # drop the stash so it is not part of the resulting dictionary
        del as_dict['batch_stash']
        return as_dict