Source code for vak.nn.functional

from typing import Optional

import torch

# Public API of this module: only ``one_hot`` is exported by ``from ... import *``.
__all__ = ["one_hot"]


# adapted from kornia, https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
def one_hot(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.

    The class dimension is inserted at position 1. ``eps`` is added to every
    element of the result, including the "hot" entries, which become
    ``1 + eps``.

    Args:
        labels (torch.Tensor): tensor with labels of shape :math:`(N, *)`,
            where N is batch size. Each value is an integer representing the
            correct class and should lie in ``[0, num_classes)``. Must have
            dtype ``torch.int64``. Values are not range-checked here;
            out-of-range values make the underlying ``Tensor.scatter_`` fail.
        num_classes (int): number of classes in labels. Must be at least 1.
        device (Optional[torch.device]): the desired device of the returned
            tensor. Default: if None, defaults to the device of ``labels``.
            ``labels`` is copied to this device before scattering.
        dtype (Optional[torch.dtype]): the desired data type of the returned
            tensor. Default: if None, uses PyTorch's default floating point
            dtype (see :func:`torch.get_default_dtype`). Note that with an
            integer ``dtype``, adding the float ``eps`` promotes the result
            to a floating point dtype under PyTorch's type-promotion rules.
        eps (float): constant added to every element of the one-hot tensor.
            Default: 1e-6.

    Returns:
        torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
        where C is ``num_classes``.

    Raises:
        TypeError: if ``labels`` is not a ``torch.Tensor``.
        ValueError: if ``labels`` does not have dtype ``torch.int64``, or if
            ``num_classes`` is less than 1.

    Examples:
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> one_hot(labels, num_classes=3)
        tensor([[[[1.0000e+00, 1.0000e-06],
                  [1.0000e-06, 1.0000e+00]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e+00],
                  [1.0000e-06, 1.0000e-06]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e-06],
                  [1.0000e+00, 1.0000e-06]]]])
    """
    if not isinstance(labels, torch.Tensor):
        raise TypeError(
            f"Input labels type is not a torch.Tensor. Got {type(labels)}"
        )
    if labels.dtype != torch.int64:
        raise ValueError(
            f"labels must be of the same dtype torch.int64. Got: {labels.dtype}"
        )
    if num_classes < 1:
        # The check accepts num_classes == 1, so the message says "at least".
        raise ValueError(
            f"The number of classes must be at least one. Got: {num_classes}"
        )

    if device is None:
        device = labels.device

    # ``scatter_`` requires its index on the same device as the tensor it
    # writes into. Move the labels so an explicit ``device`` works.
    # ``.to`` is a no-op when the device already matches.
    index = labels.to(device).unsqueeze(1)

    shape = labels.shape
    # Output shape is (N, C, *): the class axis is inserted at dim 1.
    encoded = torch.zeros(
        (shape[0], num_classes) + shape[1:], device=device, dtype=dtype
    )
    return encoded.scatter_(1, index, 1.0) + eps