# Source code for zensols.deeplearn.dataframe.batch

"""An implementation of batch level API for Pandas dataframe based data.

"""
__author__ = 'Paul Landes'

import logging
from typing import Tuple
from dataclasses import dataclass, InitVar
import pandas as pd
import torch
from zensols.deeplearn.batch import (
    BatchError,
    BatchFeatureMapping,
    BatchStash,
    DataPoint,
    Batch,
)
from zensols.deeplearn.dataframe import DataframeFeatureVectorizerManager

# module level logger named after this module
logger = logging.getLogger(__name__)


@dataclass
class DataframeBatchStash(BatchStash):
    """A stash used for batches of data using :class:`.DataframeBatch`
    instances.  This stash uses an instance of
    :class:`.DataframeFeatureVectorizerManager` to vectorize the data in the
    batches.

    """
    @property
    def feature_vectorizer_manager(self) -> DataframeFeatureVectorizerManager:
        """The single vectorizer manager of this stash's vectorizer manager
        set, which must be a :class:`.DataframeFeatureVectorizerManager`.

        :raises BatchError: if the vectorizer manager set does not have
                            exactly one manager, or that manager is not a
                            :class:`.DataframeFeatureVectorizerManager`

        """
        managers = tuple(self.vectorizer_manager_set.values())
        if len(managers) != 1:
            # format the names with an f-string; concatenating a str with a
            # tuple would raise a TypeError rather than the intended error
            names = tuple(self.vectorizer_manager_set.keys())
            raise BatchError(
                f'Expected only one vector manager but got: {names}')
        vec_mng = managers[0]
        if not isinstance(vec_mng, DataframeFeatureVectorizerManager):
            raise BatchError(
                'Expected class of type DataframeFeatureVectorizerManager ' +
                f'but got {vec_mng.__class__}')
        return vec_mng

    @property
    def label_shape(self) -> Tuple[int, ...]:
        """The label shape as reported by the dataframe vectorizer manager.

        """
        return self.feature_vectorizer_manager.label_shape

    @property
    def flattened_features_shape(self) -> Tuple[int, ...]:
        """The flattened features shape computed by the dataframe vectorizer
        manager for this stash's decoded attributes.

        """
        vec_mng = self.feature_vectorizer_manager
        return vec_mng.get_flattened_features_shape(self.decoded_attributes)
@dataclass
class DataframeDataPoint(DataPoint):
    """A batch data point holding a single row of a Pandas dataframe.

    On construction, every column of the given row is copied onto this
    instance as an attribute named by the column label.

    """
    # the dataframe row; an init-only variable consumed by ``__post_init__``
    row: InitVar[pd.Series]

    def __post_init__(self, row: pd.Series):
        for col_name, col_value in row.items():
            debug_on = logger.isEnabledFor(logging.DEBUG)
            if debug_on:
                logger.debug(f'setting attrib: {col_name}={col_value}')
            setattr(self, col_name, col_value)
@dataclass
class DataframeBatch(Batch):
    """A batch of data that contains instances of :class:`.DataframeDataPoint`,
    each of which has the row data from the dataframe.

    """
    def _get_batch_feature_mappings(self) -> BatchFeatureMapping:
        """Use the dataframe based vectorizer manager.

        :return: the batch feature mapping of the batch stash's
                 :class:`.DataframeFeatureVectorizerManager`

        """
        df_vec_mng: DataframeFeatureVectorizerManager = \
            self.batch_stash.feature_vectorizer_manager
        return df_vec_mng.batch_feature_mapping

    def get_features(self) -> torch.Tensor:
        """A utility method to return a tensor of all features of all columns
        in the datapoints.

        :return: a tensor of shape (batch size, feature size), where the
                 *feature size* is the number of all features vectorized;
                 that is, a data instance for each row in the batch, is a
                 flattened set of features that represent the respective row
                 from the dataframe

        """
        def as_2d(arr: torch.Tensor) -> torch.Tensor:
            """Return ``arr`` reshaped from ``(N,)`` to ``(N, 1)`` when it is
            one dimensional so it can be concatenated along the second
            (feature) dimension; otherwise return it unchanged.

            """
            if len(arr.shape) == 1:
                arr = arr.unsqueeze(dim=1)
            return arr

        attrs = self.attributes
        label_attr = self._get_batch_feature_mappings().label_attribute_name
        # every attribute other than the label is a feature; features are
        # concatenated in ``attrs.keys()`` order
        feats = tuple(as_2d(attrs[name]) for name in attrs.keys()
                      if name != label_attr)
        # NOTE: torch.cat raises a RuntimeError when ``feats`` is empty, that
        # is, when the batch has no non-label attributes
        return torch.cat(feats, dim=1)