from typing import Optional

import torch

__all__ = ["one_hot"]


# adapted from kornia, https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
def one_hot(
    labels: torch.Tensor,
    num_classes: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.

    The class axis is inserted at dimension 1, i.e. right after the batch
    axis. ``eps`` is added to every entry of the result, so "hot" positions
    hold ``1 + eps`` and all other positions hold ``eps``.

    Args:
        labels (torch.Tensor): tensor with labels of shape :math:`(N, *)`,
            where N is batch size. Each value is an integer in
            ``[0, num_classes)`` representing correct classification. Values
            outside that range are not validated here; ``scatter_`` rejects
            them at runtime.
        num_classes (int): number of classes in labels. Must be at least one.
        device (Optional[torch.device]): the desired device of returned
            tensor. Default: if None, defaults to device of `labels`.
        dtype (Optional[torch.dtype]): the desired data type of returned
            tensor. Default: if None, uses PyTorch's default floating point
            dtype (see :func:`torch.get_default_dtype`). Note that adding the
            Python float ``eps`` promotes an integer ``dtype`` to floating
            point in the returned tensor.
        eps (float): constant added to every entry of the one-hot tensor.
            Default: 1e-6.

    Returns:
        torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`.

    Raises:
        TypeError: if ``labels`` is not a ``torch.Tensor``.
        ValueError: if ``labels`` is not of dtype ``torch.int64``, if
            ``num_classes`` is smaller than one, or if ``labels`` is
            0-dimensional (it has no batch axis).

    Examples:
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> one_hot(labels, num_classes=3)
        tensor([[[[1.0000e+00, 1.0000e-06],
                  [1.0000e-06, 1.0000e+00]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e+00],
                  [1.0000e-06, 1.0000e-06]],
        <BLANKLINE>
                 [[1.0000e-06, 1.0000e-06],
                  [1.0000e+00, 1.0000e-06]]]])
    """
    if not isinstance(labels, torch.Tensor):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}".format(type(labels)))
    if labels.dtype != torch.int64:
        raise ValueError("labels must be of the same dtype torch.int64. Got: {}".format(labels.dtype))
    if num_classes < 1:
        # The check accepts num_classes == 1, so the message says "at least one".
        raise ValueError("The number of classes must be at least one. Got: {}".format(num_classes))
    if labels.dim() < 1:
        # Without this check, shape[0] below would fail with an opaque IndexError.
        raise ValueError(
            "labels must have at least one (batch) dimension. Got shape: {}".format(tuple(labels.shape))
        )

    if device is None:
        device = labels.device

    shape = labels.shape
    # (N, *) -> (N, C, *): the class axis goes right after the batch axis.
    encoded = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
    # scatter_ needs its index on the same device as the destination tensor,
    # which differs from labels.device when an explicit `device` is given.
    index = labels.to(device).unsqueeze(1)
    return encoded.scatter_(1, index, 1.0) + eps