在PyTorch框架下使用F.cross_entropy()函数时,偶尔会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed
。
错误信息类似下面打印信息:
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):
File "tutorial.py", line 100, in <module>
model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)
File "tutorial.py", line 80, in train_model
loss = criterion(outputs, labels)
File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
result = self.forward(*input, **kwargs)
File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forward
self.weight, self.size_average)
File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropy
return nll_loss(log_softmax(input), target, weight, size_average)
File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_loss
return f(input, target)
File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
output, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83
通常情况下,这是由于求交叉熵函数在计算时遇到了类别错误的问题,即不满足t >= 0 && t < n_classes
条件。
在分类任务中,需要调用torch.nn.functional.cross_entropy()
函数求交叉熵,从PyTorch官网可以看到该函数定义:
torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
可以注意到有一个key-value是ignore_index=-100。这是在交叉熵计算时被跳过的部分。通常是在数据增强中的填充值。
而在代码运行中报错ClassNLLCriterion Assertion `t >= 0 && t < n_classes ` failed
,大部分都是由于没有正确处理好label(ground truth)导致的。例如在数据增强中,填充数据使用了负数,或者使用了某大正数(如255),而在调用torch.nn.functional.cross_entropy()
方法时却没有传入正确的ignore_index。这就会导致运行过程中的Assertion Error。
import torchvision.transforms.functional as tf
tf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0,fillcolor=250,)
import torch
import torch.nn.functional as F
import torch.nn as nn
def cross_entropy2d(input, target, weight=None, reduction='none'):
n, c, h, w = input.size()
nt, ht, wt = target.size()
if h != ht or w != wt:
input = F.interpolate(input, size=(
ht, wt), mode="bilinear", align_corners=True)
input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
target = target.view(-1)
loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)
return loss
可以看到在数据增强时的填充值为250(fillcolor=250),但在求交叉熵时却传入了ignore_index=255。因此在代码运行时,F.cross_entropy部分便会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed
。只需要统一好label部分填充数据和计算交叉熵时需要忽略的class就可以避免出现这一问题。
在PyTorch框架下,使用无用label值进行填充和处理时,要注意在使用scatter_
函数时也需要注意对无用label进行提前处理,否则在使用data.scatter_()
时同样也会报类似类别index错误。
labels = labels[:, :, :].view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# ignore_index=255
# labels[labels.data[::] == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)
[1] torch.nn.functional — PyTorch 1.8.0 documentation
[2] Pytorch里的CrossEntropyLoss详解 - marsggbo - 博客园
[3] RuntimeError: cuda runtime error (59) : device-side assert triggered when running transfer_learning_tutorial · Issue #1204 · pytorch/pytorch
[4] PyTorch 中,nn 与 nn.functional 有什么区别? - 知乎
[5] FaceParsing.PyTorch/augmentations.py at master · TracelessLe/FaceParsing.PyTorch