在运行pytorch模型进行训练时,CUDA报错:
"/pytorch/aten/src/THC/THCTensorScatterGather.cu:188: void THCudaTensor_scatterFillKernel(TensorInfo, TensorInfo, Real, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = -1]: block: [536,0,0], thread: [319,0,0] Assertion `indexValue >= 0 && indexValue < tensor.sizes[dim]` failed."
cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
"/pytorch/aten/src/THC/THCTensorScatterGather.cu:188: void THCudaTensor_scatterFillKernel(TensorInfo, TensorInfo, Real, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = -1]: block: [4,0,0], thread: [415,0,0] Assertion `indexValue >= 0 && indexValue < tensor.sizes[dim]` failed."
CUDA error: device-side assert triggered
在错误提示代码/pytorch/aten/src/THC/THCTensorScatterGather.cu:188中,可以看出是pytorch中的ScatterGather报错,而这是因为代码中使用scatter_()函数导致的,
这是由于scatter_(dim, index, src, value)函数在执行时,索引index张量中的数值超过了维度参数dim所在维度的最大值,造成越界报错。
而这种错误可能会导致CUDA的cuDNN error: CUDNN_STATUS_NOT_INITIALIZED与CUDA error: device-side assert triggered两种错误。
在使用scatter_()函数时,要注意index张量中的最大值<= dim维度的最大值。
对于scatter_(dim, index, src_tensor, value),其中src_tensor与value是二选一,有且仅有一个;index是long型的tensor。
例如:
img = torch.randint(151, (1, 1, 256, 256))
size = img.size()
img_b_size = (size[0], 151, size[2], size[3])
img_b = torch.Tensor(torch.Size(img_b_size)).zero_()
img_b = img_b.scatter_(1, img.data.long(), 1.0)
# 其中第一个参数1是指dim=1,即第二个维度,对于图像来说就是通道维度,将在第二个维度进行数值的映射
# 第二个参数img.data.long()即是index,其数值范围的最大值要<= img_b的通道数量的最大值,此处img_b的通道数为151
# 即max(img.data.long())<=(151-1)=150, 因为索引都是从0开始
# 第三个参数1.0即value参数,因为有value参数,所以src_tensor参数就不用了
如果是提供src_tensor参数,同样src_tensor的数值要<=img_b所需要投射所在维度的最大值。
例如:
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
0.0000 0.0000 1.2300 0.0000
0.0000 0.0000 0.0000 1.2300
[torch.FloatTensor of size 2x4]
详细关于scatter_()函数的使用方式可参考Pytorch文档:
https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor