Source code for vak.nn.loss.crossentropy

import torch


class CrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Wrapper around :class:`torch.nn.CrossEntropyLoss`.

    Converts the argument ``weight`` to a :class:`torch.Tensor`
    if it is a :class:`list` or a :class:`tuple`, so that class
    weights can be given as a plain Python sequence (e.g., as
    loaded from a configuration file).
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        ignore_index=-100,
        reduce=None,
        reduction="mean",
        label_smoothing=0.0,
    ):
        """Initialize the loss, converting ``weight`` to a tensor if needed.

        Parameters
        ----------
        weight : torch.Tensor, list, tuple, or None
            Class weights. A :class:`list` or :class:`tuple` is converted
            to a :class:`torch.Tensor` with the default floating point
            dtype. A :class:`torch.Tensor` or ``None`` is passed to the
            parent class unchanged. Any other type is also passed through
            unchanged, leaving validation to the parent class.
        size_average, ignore_index, reduce, reduction, label_smoothing
            Passed unchanged to :class:`torch.nn.CrossEntropyLoss`.
        """
        # Tensors and ``None`` need no conversion; only plain Python
        # sequences are turned into a tensor here.
        if isinstance(weight, (list, tuple)):
            # An explicit dtype keeps integer-valued sequences floating
            # point, matching the legacy ``torch.Tensor(data)`` behavior.
            weight = torch.tensor(weight, dtype=torch.get_default_dtype())
        super().__init__(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )