Source code for zensols.deeplearn.vectorize.util
"""Utiliies for encoding and decoding tensors.
"""
__author__ = 'Paul Landes'
from typing import Tuple, List, Sequence
from dataclasses import dataclass
from itertools import chain
from functools import reduce
import torch
from torch import Tensor
from . import TorchConfig
@dataclass
class NonUniformDimensionEncoder(object):
    """Encode a sequence of tensors, each of arbitrary dimensionality, as a 1-D
    array.  Then decode the 1-D array back to the original.

    The 1-D encoding layout is: the tensor count, then for each tensor its
    rank followed by each dimension size, then each tensor's flattened data.

    """
    # factory used to create the metadata tensor on the configured device
    # (only its ``singleton`` method is used here)
    torch_config: TorchConfig

    def encode(self, arrs: Sequence[Tensor]) -> Tensor:
        """Encode a sequence of tensors, each of arbitrary dimensionality, as a
        1-D array.

        :param arrs: the non-empty tensors to encode; the metadata header is
                     created with the dtype of the first tensor

        :return: a 1-D tensor holding the shape metadata followed by each
                 tensor's flattened data

        :raises ValueError: if ``arrs`` is empty

        """
        if len(arrs) == 0:
            raise ValueError('can not encode an empty sequence of tensors')

        def map_tensor_meta(arr: Tensor) -> List[int]:
            # per-tensor metadata: the rank, then each dimension's size
            sz = arr.shape
            tm: List[int] = [len(sz)]
            tm.extend(sz)
            return tm

        # header: tensor count, then each tensor's (rank, dims...) metadata
        tmeta: List[int] = [len(arrs)]
        tmeta.extend(chain.from_iterable(map(map_tensor_meta, arrs)))
        # store metadata in the data's dtype so everything fits in one tensor;
        # NOTE(review): very large dimension sizes could lose precision in
        # low-precision float dtypes (e.g. float16)
        meta_arr: Tensor = self.torch_config.singleton(
            tmeta, dtype=arrs[0].dtype)
        return torch.cat([meta_arr] + [t.flatten() for t in arrs])

    def decode(self, arr: Tensor) -> Tuple[Tensor, ...]:
        """Decode the 1-D array produced by :meth:`encode` back to the
        original sequence of tensors.

        :param arr: a 1-D tensor in the layout written by :meth:`encode`

        :return: the original tensors with their shapes restored

        """
        ix_type = torch.long
        # header position 0: the number of encoded tensors; plain Python ints
        # are used for indexing (equivalent to the original tensor indices,
        # but avoids a device sync per arithmetic op)
        n_arrs = int(arr[0].type(ix_type))
        start = 1
        # first pass: read each tensor's (rank, dims...) metadata
        shapes: List[Tuple[int, ...]] = []
        for _ in range(n_arrs):
            rank = int(arr[start].type(ix_type))
            start += 1
            end = start + rank
            shapes.append(tuple(map(int, arr[start:end])))
            start = end
        # second pass: slice out each tensor's flattened data and reshape;
        # the initializer of 1 makes 0-dim (scalar) tensors decode correctly
        arrs: List[Tensor] = []
        for shape in shapes:
            ln = reduce(lambda x, y: x * y, shape, 1)
            end = start + ln
            arrs.append(arr[start:end].view(shape))
            start = end
        return tuple(arrs)