Source code for zensols.deeplearn.layer.crf

"""Conditional random field PyTorch module forked from Kemal Kurniawan's
``pytorch_crf`` GitHub repository.  See the ``Torch CRF`` section of the
``README.md`` module documentation for more information.

:see: `pytorch_crf <https://github.com/kmkurn/pytorch-crf>`_
:see: `Torch CRF Readme <https://github.com/plandes/deeplearn#torch-crf>`_

"""
__author__ = 'Kemal Kurniawan, Paul Landes'
__version__ = '0.7.5'

from typing import List, Optional, Union, Tuple

import torch
import torch.nn as nn

from . import LayerError


[docs] class CRF(nn.Module): """Conditional random field. This module implements a conditional random field [LMP01]_. The forward computation of this class computes the log likelihood of the given sequence of tags and emission score tensor. This class also has `~CRF.decode` method which finds the best tag sequence given an emission score tensor using `Viterbi algorithm`_. Args: num_tags: Number of tags. batch_first: Whether the first dimension corresponds to the size of a minibatch. score_reduction: reduces how the score output over batches, and then tags, and has shape ``(batch size, number of tags)`` with the exception of ``tags``, which has shape ``(batch_size, sequence length, number of tags)``; how output is returned in :meth:`decode` by: - **skip**: do not return scores, only the decoded output (default) - **none**: return the scores unaltered, then divide by the batch count - **tags**: all scores - **sum**: sum the max over batches, then divide by the batch count - **max**: max over each batch max, then divide by the batch count - **min**: min over each batch max, then divide by the batch count - **mean**: average the max over batchs, then divide by the batch count Attributes: start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size ``(num_tags,)``. end_transitions (`~torch.nn.Parameter`): End transition score tensor of size ``(num_tags,)``. transitions (`~torch.nn.Parameter`): Transition score tensor of size ``(num_tags, num_tags)``. .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). "Conditional random fields: Probabilistic models for segmenting and labeling sequence data". *Proc. 18th International Conf. on Machine Learning*. Morgan Kaufmann. pp. 282–289. .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm """
[docs] def __init__(self, num_tags: int, batch_first: bool = False, score_reduction: str = 'skip') -> None: if num_tags <= 0: raise LayerError(f'Invalid number of tags: {num_tags}') super().__init__() self.num_tags = num_tags self.batch_first = batch_first self.score_reduction = score_reduction self.start_transitions = nn.Parameter(torch.empty(num_tags)) self.end_transitions = nn.Parameter(torch.empty(num_tags)) self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) self.reset_parameters()
[docs] def reset_parameters(self) -> None: """Initialize the transition parameters. The parameters will be initialized randomly from a uniform distribution between -0.1 and 0.1. """ nn.init.uniform_(self.start_transitions, -0.1, 0.1) nn.init.uniform_(self.end_transitions, -0.1, 0.1) nn.init.uniform_(self.transitions, -0.1, 0.1)
def __repr__(self) -> str: return f'{self.__class__.__name__}(num_tags={self.num_tags})'
[docs] def forward( self, emissions: torch.Tensor, tags: torch.LongTensor, mask: Optional[torch.ByteTensor] = None, reduction: str = 'sum', ) -> torch.Tensor: """Compute the conditional log likelihood of a sequence of tags given emission scores. Args: emissions (`~torch.Tensor`): Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise. tags (`~torch.LongTensor`): Sequence of tags tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. reduction: Specifies the reduction to apply to the output: ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. ``sum``: the output will be summed over batches. ``mean``: the output will be averaged over batches. ``token_mean``: the output will be averaged over tokens. Returns: `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise. """ self._validate(emissions, tags=tags, mask=mask) if reduction not in ('none', 'sum', 'mean', 'token_mean'): raise LayerError(f'Invalid reduction: {reduction}') if mask is None: mask = torch.ones_like(tags, dtype=torch.uint8) if self.batch_first: emissions = emissions.transpose(0, 1) tags = tags.transpose(0, 1) mask = mask.transpose(0, 1) # shape: (batch_size,) numerator = self._compute_score(emissions, tags, mask) # shape: (batch_size,) denominator = self._compute_normalizer(emissions, mask) # shape: (batch_size,) llh = numerator - denominator if reduction == 'none': return llh if reduction == 'sum': return llh.sum() if reduction == 'mean': return llh.mean() assert reduction == 'token_mean' return llh.sum() / mask.type_as(emissions).sum()
[docs] def decode(self, emissions: torch.Tensor, mask: Optional[torch.ByteTensor] = None) -> \ Union[List[List[int]], Tuple[List[List[int]], torch.Tensor]]: """Find the most likely tag sequence using Viterbi algorithm. Args: emissions (`~torch.Tensor`): Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise. mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. Returns: List of list containing the best tag sequence for each batch and optionally the scores based on the (~`score_reduction`) parameter in :meth:`__init__`. """ self._validate(emissions, mask=mask) if mask is None: mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) if self.batch_first: emissions = emissions.transpose(0, 1) mask = mask.transpose(0, 1) return self._viterbi_decode(emissions, mask)
def _validate( self, emissions: torch.Tensor, tags: Optional[torch.LongTensor] = None, mask: Optional[torch.ByteTensor] = None) -> None: if emissions.dim() != 3: raise LayerError(f'Emissions must have dimension of 3, got {emissions.dim()}') if emissions.size(2) != self.num_tags: raise LayerError( f'Expected last dimension of emissions is {self.num_tags}, ' f'got {emissions.size(2)}') if tags is not None: if emissions.shape[:2] != tags.shape: raise LayerError( 'The first two dimensions of emissions and tags must match, ' f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') if mask is not None: if emissions.shape[:2] != mask.shape: raise LayerError( 'The first two dimensions of emissions and mask must match, ' f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') no_empty_seq = not self.batch_first and mask[0].all() no_empty_seq_bf = self.batch_first and mask[:, 0].all() if not no_empty_seq and not no_empty_seq_bf: raise LayerError('mask of the first timestep must all be on') def _compute_score( self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> torch.Tensor: # emissions: (seq_length, batch_size, num_tags) # tags: (seq_length, batch_size) # mask: (seq_length, batch_size) assert emissions.dim() == 3 and tags.dim() == 2 assert emissions.shape[:2] == tags.shape assert emissions.size(2) == self.num_tags assert mask.shape == tags.shape assert mask[0].all() seq_length, batch_size = tags.shape mask = mask.type_as(emissions) # Start transition score and first emission # shape: (batch_size,) score = self.start_transitions[tags[0]] score += emissions[0, torch.arange(batch_size), tags[0]] for i in range(1, seq_length): # Transition score to next tag, only added if next timestep is valid (mask == 1) # shape: (batch_size,) score += self.transitions[tags[i - 1], tags[i]] * mask[i] # Emission score for next tag, only added if next timestep is valid (mask == 1) # shape: (batch_size,) score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] # End transition score # shape: (batch_size,) seq_ends = mask.long().sum(dim=0) - 1 # shape: (batch_size,) last_tags = tags[seq_ends, torch.arange(batch_size)] # shape: (batch_size,) score += self.end_transitions[last_tags] return score def _compute_normalizer( self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor: # emissions: (seq_length, batch_size, num_tags) # mask: (seq_length, batch_size) assert emissions.dim() == 3 and mask.dim() == 2 assert emissions.shape[:2] == mask.shape assert emissions.size(2) == self.num_tags assert mask[0].all() seq_length = emissions.size(0) # Start transition score and first emission; score has size of # (batch_size, num_tags) where for each batch, the j-th column stores # the score that the first timestep has tag j # shape: (batch_size, num_tags) score = self.start_transitions + emissions[0] for i in range(1, seq_length): # Broadcast score for every possible next tag # shape: (batch_size, num_tags, 1) broadcast_score = score.unsqueeze(2) # Broadcast emission score for every possible current tag # shape: (batch_size, 1, num_tags) broadcast_emissions = emissions[i].unsqueeze(1) # Compute the score tensor of size (batch_size, num_tags, num_tags) where # for each sample, entry at row i and column j stores the sum of scores of all # possible tag sequences so far that end with transitioning from tag i to tag j # and emitting # shape: (batch_size, num_tags, num_tags) next_score = broadcast_score + self.transitions + broadcast_emissions # Sum over all possible current tags, but we're in score space, so a sum # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of # all possible tag sequences so far, that end in tag i # shape: (batch_size, num_tags) next_score = torch.logsumexp(next_score, dim=1) # Set score to the next score if this timestep is valid (mask == 1) # shape: (batch_size, num_tags) score = torch.where(mask[i].unsqueeze(1), next_score, score) # End transition score # shape: (batch_size, num_tags) score += self.end_transitions # Sum (log-sum-exp) over all possible tags # shape: (batch_size,) return torch.logsumexp(score, dim=1) def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) -> List[List[int]]: # emissions: (seq_length, batch_size, num_tags) # mask: (seq_length, batch_size) assert emissions.dim() == 3 and mask.dim() == 2 assert emissions.shape[:2] == mask.shape assert emissions.size(2) == self.num_tags assert mask[0].all() seq_length, batch_size = mask.shape # Start transition and first emission # shape: (batch_size, num_tags) score = self.start_transitions + emissions[0] history = [] # Keep the scores to later return if the user wants them them based on # the score reduction, which will have shape: # (batch_size, seq_len, num_tags) if self.score_reduction == 'tags': tag_scores = [] else: tag_scores = None # score is a tensor of size (batch_size, num_tags) where for every batch, # value at column j stores the score of the best tag sequence so far that ends # with tag j # history saves where the best tags candidate transitioned from; this is used # when we trace back the best tag sequence # Viterbi algorithm recursive case: we compute the score of the best tag sequence # for every possible next tag for i in range(1, seq_length): # Broadcast viterbi score for every possible next tag # shape: (batch_size, num_tags, 1) broadcast_score = score.unsqueeze(2) # Broadcast emission score for every possible current tag # shape: (batch_size, 1, num_tags) broadcast_emission = emissions[i].unsqueeze(1) # Compute the score tensor of size (batch_size, num_tags, num_tags) where # for each sample, entry at row i and column j stores the score of the best # tag sequence so far that ends with transitioning from tag i to tag j and emitting # shape: (batch_size, num_tags, num_tags) next_score = broadcast_score + self.transitions + broadcast_emission # Find the maximum score over all possible current tag # shape: (batch_size, num_tags) next_score, indices = next_score.max(dim=1) # Set score to the next score if this timestep is valid (mask == 1) # and save the index that produces the next score # shape: (batch_size, num_tags) score = torch.where(mask[i].unsqueeze(1), next_score, score) # Add scores to cat them later based on the score reduction # shape: (batch_size, num_tags) if tag_scores is not None: tag_scores.append(score.detach().unsqueeze(1)) history.append(indices) # End transition score # shape: (batch_size, num_tags) score += self.end_transitions # Add the final transition score to our returned scores # shape: (batch_size, num_tags) if tag_scores is not None: tag_scores.append(score.detach().unsqueeze(1)) # Now, compute the best path for each sample # shape: (batch_size,) seq_ends = mask.long().sum(dim=0) - 1 best_tags_list = [] for idx in range(batch_size): # Find the tag which maximizes the score at the last timestep; this is our best tag # for the last timestep _, best_last_tag = score[idx].max(dim=0) best_tags = [best_last_tag.item()] # We trace back where the best last tag comes from, append that to our best tag # sequence, and trace it back again, and so on for hist in reversed(history[:seq_ends[idx]]): best_last_tag = hist[idx][best_tags[-1]] best_tags.append(best_last_tag.item()) # Reverse the order because we start from the last timestep best_tags.reverse() best_tags_list.append(best_tags) # If the user wants scores, return them in the desired format if self.score_reduction == 'skip': return best_tags_list else: if self.score_reduction == 'tags': # shape: (batch_size, seq_length, num_tags) score = torch.cat(tag_scores, 1) elif self.score_reduction != 'none': score = score.max(dim=1)[0] score = {'sum': score.sum(), 'max': score.max(), 'min': score.min(), 'mean': score.mean()}[self.score_reduction] score = score / batch_size return best_tags_list, score