Notes on the PyTorch F.cross_entropy error: Assertion `t >= 0 && t < n_classes` failed

Preface


When using F.cross_entropy() under the PyTorch framework, you may occasionally hit the error ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes` failed.

The error output looks similar to the following:

/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "tutorial.py", line 100, in <module>
    model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)
  File "tutorial.py", line 80, in train_model
    loss = criterion(outputs, labels)
  File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forward
    self.weight, self.size_average)
  File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_loss
    return f(input, target)
  File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83

Usually this means the cross-entropy computation encountered an invalid class label, i.e. a target value that violates the condition t >= 0 && t < n_classes.
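As a quick illustration (a minimal sketch with arbitrary shapes, not taken from any real project), feeding a target value outside [0, n_classes) reproduces the failure. On CPU the check surfaces as an out-of-bounds error, while on CUDA it appears as the device-side assert above:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                 # 4 samples, n_classes = 3
bad_target = torch.tensor([0, 1, 2, 5])    # 5 violates t < n_classes

# On CPU this raises an out-of-bounds error; on CUDA the same check
# fails as the device-side assert `t >= 0 && t < n_classes`.
try:
    loss = F.cross_entropy(logits, bad_target)
except Exception as e:
    print(type(e).__name__, e)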

The t >= 0 && t < n_classes condition

In classification tasks, torch.nn.functional.cross_entropy() is called to compute the cross-entropy loss. Its definition, as shown in the official PyTorch documentation, is:

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

Note the keyword argument ignore_index=-100. Target entries equal to ignore_index are skipped when the cross-entropy is computed; in practice this value is usually the fill value introduced during data augmentation.
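For example (a minimal sketch, shapes chosen arbitrarily), label positions marked with the default ignore_index of -100 are simply excluded from the loss and never hit the range check:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                  # n_classes = 3
target = torch.tensor([0, 2, -100, -100])   # -100 marks padded positions

# Entries equal to ignore_index are excluded from the loss, so the
# out-of-range value -100 does not trigger the assertion.
loss = F.cross_entropy(logits, target, ignore_index=-100)
print(loss)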

When ClassNLLCriterion Assertion `t >= 0 && t < n_classes` failed shows up at runtime, it is most often because the labels (ground truth) were not handled correctly. For example, the data-augmentation padding may use a negative value or a large positive value (such as 255), while torch.nn.functional.cross_entropy() is called without the matching ignore_index. That mismatch triggers the assertion error at runtime.


Code example

Data augmentation part

import torchvision.transforms.functional as tf

# Pad the cropped image with reflection padding
img = tf.pad(cropped_img, padding_tuple, padding_mode="reflect")
# Translate the mask; regions exposed by the shift are filled with 250
mask = tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0,
                 angle=0.0, shear=0.0, fillcolor=250)

Cross-entropy part

import torch
import torch.nn.functional as F
import torch.nn as nn


def cross_entropy2d(input, target, weight=None, reduction='none'):
    # input: (N, C, H, W) logits; target: (N, H, W) integer label map
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Resize the logits if their spatial size differs from the target's
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # Flatten to (N*H*W, C) logits and (N*H*W,) labels for F.cross_entropy
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)
    # Pixels labelled 255 are ignored when computing the loss
    loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)

    return loss
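Running the two fragments together reproduces the problem: the mask is padded with 250 (the augmentation's fill value), while ignore_index is 255. The sketch below continues from the imports and function above and uses hypothetical shapes; on CPU the check surfaces as an out-of-bounds error rather than the device-side assert:

logits = torch.randn(2, 3, 8, 8)            # (N, C, H, W), n_classes = 3
mask = torch.randint(0, 3, (2, 8, 8))
mask[:, 0, :] = 250                         # fill value used by the augmentation
try:
    loss = cross_entropy2d(logits, mask)    # ignore_index=255 does not cover 250
except Exception as e:
    print(type(e).__name__, e)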

Analysis

Notice that the fill value used during data augmentation is 250 (fillcolor=250), yet ignore_index=255 is passed when computing the cross-entropy. As a result, the F.cross_entropy call fails at runtime with ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes` failed. Keeping the label fill value and the class ignored by the cross-entropy consistent is enough to avoid the problem.
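A minimal sketch of the fix (arbitrary shapes; 250 kept as the fill value): define the ignore label once and pass the same value to both the augmentation's fillcolor and cross_entropy's ignore_index.

import torch
import torch.nn.functional as F

IGNORE_LABEL = 250                        # single constant shared by fillcolor and ignore_index

logits = torch.randn(2, 3, 4, 4)          # (N, C, H, W), n_classes = 3
target = torch.randint(0, 3, (2, 4, 4))
target[:, :, -1] = IGNORE_LABEL           # pretend the last column is augmentation padding

# With the fill value and ignore_index unified, the padded pixels are skipped
loss = F.cross_entropy(logits, target, ignore_index=IGNORE_LABEL)
print(loss)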

Other notes

Also in PyTorch, when a dummy label value is used for padding, the invalid labels must be handled before calling scatter_, otherwise data.scatter_() will raise a similar class-index error.

# labels: (N, H, W) integer map; add a channel dim so it can index the class axis
labels = labels.view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# Remap the ignore label to a valid class first, otherwise scatter_ fails
# with a similar index-out-of-range error:
# ignore_index = 255
# labels[labels == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)
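As a small illustration (hypothetical shapes, CPU tensors for simplicity), remapping the ignore label to a valid class before calling scatter_ avoids the index error:

import torch

classes, ignore_index = 3, 255
labels = torch.tensor([[[0, 1], [2, 255]]])   # (N=1, H=2, W=2) with one ignore pixel
labels = labels.view(1, 1, 2, 2)

labels[labels == ignore_index] = 0            # remap the ignore label to a valid class
one_hot = torch.zeros(1, classes, 2, 2)
one_hot.scatter_(1, labels.long(), 1.0)
print(one_hot)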


