记录使用mmseg时在计算交叉熵损失遇到的RuntimeError问题与解决方案

目录

问题描述:

非常心酸的绕了一个大圈的debug历程

解决方法

再总结一下其他几个容易踩雷的地方吧:

最后再来总结一下这次debug的经验心得:


问题描述:

在使用mmseg在自己的数据集上训练语义分割模型时,遇到了一个很奇怪的RuntimeError,翻遍了内网外网都没有找到合适的解决方案。

bug如下:

RuntimeError: CUDA error: an illegal memory access was encountered
    correct = correct[:, target != ignore_index]
RuntimeError: CUDA error: an illegal memory access was encountered
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: an illegal memory access was encountered

(前面的一大堆Traceback还有后面的frame #就不写了,总之这里应该是RuntimeError的错误根源)

非常心酸的绕了一个大圈的debug历程

以下是我的debug经历,不感兴趣的可以直接跳到最后看解决方法。

先在网上查了一下,发现RuntimeError: CUDA error: an illegal memory access was encountered这个错误的原因各种各样,有些博主给出了统一的解决方式,比如缩小batch size 或者缩小image size或者调整一下并行计算策略什么的,但是我把这些方案挨个尝试了一遍以后发现统统不管用。

所以我就想,我遇到的这个RuntimeError应该不是真正的GPU内存问题。把整个运行日志看了一遍以后发现,好家伙,不要说一个iter都没跑完,根本就是一张图都没跑出来!再看看RuntimeError之前的最后一句话:

correct = correct[:, target != ignore_index]

于是我就想可能这里就是错误根源了。在mmseg的官方文档里搜索了一下这句话,发现这是ignore_index != None才会出现的。

那么问题来了,这个ignore_index到底是个嘛玩意呢?

又翻了一下官方文档和源码以后发现,在语义分割模型里,尤其是计算Cross Entropy Loss与Accuracy的过程中,可以忽略一些不重要的像素。而这个ignore_index就是被忽略的像素。

所以我出现了这个错误是因为我的数据集里有些像素被标记为ignore了,但是mmseg的源码里avg_non_ignore是默认为False的,也就是说默认采用不忽略的方式。那么矛盾就发生在了这里,有像素被忽略掉,但是又采用了不忽略的方式,那怪不得要报错。

梳理一下逻辑顺序:avg_non_ignore默认为False(即以不忽略像素的方式计算accuracy)→ignore_index理应标记为None→实际上ignore_index不为None→出现矛盾,报错。

梳理清楚了错误根源,就可以按图索骥地去找解决方案了。首先就是要找到到底哪类像素被标记为了ignore_index。继续翻mmseg的源码,发现如果用的是Cross Entropy Loss,那么会有三种选择:

如果use_sigmoid设置为True(默认为False),那么就采用BCEloss(针对二分类问题)。在这种情况下avg_non_ignore默认为False,而ignore_index默认为-100。

如果use_mask设置为True(默认为False),那么就采用Mask Entropy Loss。在这种情况下avg_non_ignore默认为False,而ignore_index则默认为None。

如果上面两个都是False,那么就采用真正的Cross Entropy Loss。在这种情况下情况下avg_non_ignore默认为False,而ignore_index默认为-100。

而我的config文件里,use_sigmoid和use_mask都是默认值,即False,那么我的模型采用的就是最后一种方案,也就是说ignore_index被默认为-100。

在我自己的数据集里,有三类像素,背景被标记为[0,0,0],还有一类是[100,100,100],另一类是[255,255,255]。所以这跟ignore_index的-100和None两个取值有什么关系啊喂!!!

那就只能换个思路了。既然因为avg_non_ignore设置为了False所以出现了Runtime Error,那把它改成True不就可以了吗?

然后就出现了另一种Runtime Error……

avg_factor = label.numel() - (label == ignore_index).sum().item()
RuntimeError: CUDA error: an illegal memory access was encountered
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: an illegal memory access was encountered

好家伙非得把路堵死是吧?

然后这个bug就困扰了我三天的时间(非常不想面对所以逃避了一下)。在这三天的时间里,抛去大量摸鱼时间以外,我把内外网都翻遍了也没找到答案。Github上也有一些人在用mmsegmentation 计算loss(尤其是Cross Entropy Loss)的时候遇到了这个问题,但是没有一个人给出解决方法的。所以这三天里我感到无比绝望,觉得本学术废物好不容易盼来的科研春天就要葬送在这个小小的Runtime Error上了(而且这个bug真的很没有价值,所以就很气)。


逃避了三天过后我重新打起精神,理清了解决这个bug的关键。既然我的任务里完全不需要忽略部分像素的话,那就得搞清楚这个ignore_index到底是啥。于是我在mmseg/models/losses里找到了accuracy.py文件,加了一行代码:

if ignore_index is not None:
        correct = correct[:, target != ignore_index]
        print(ignore_index)#没错就是只加了这一小句

这一加可不得了,运行结果把我惊呆了,屏幕上出现了好几个255!

震惊之余我回忆了一下自己在Github上看到的类似情景,大家的数据集里都是有255的标签。原来如此……ignore_index是被默认为255的,而我自己的数据里又有像素被标为255但是又并不是需要忽略的像素,这才是问题的根源。

随后我又翻了一下mmseg的官方文档,果不其然,看我在binary_cross_entropy的源码里找到了什么:

记录使用mmseg时在计算交叉熵损失遇到的RuntimeError问题与解决方案_第1张图片

 此时此刻的心情有点像是柳暗花明又一村,但是我并不是很高兴反而很想骂人……

后来我又看了一些语义分割相关的帖子(大部分都是稀里糊涂用labelme标一下就开跑了所以说得也并不是很清楚),翻了翻官方文档,又问了一个做分割的师兄,这才知道语义分割的数据集一般并不是你想标多少就标多少的,而是按照0,1,2,3……的顺序对类别打标签的。

至于class_names.py和cityscapes.py中的PALETTE,那个只是做完分割以后给像素涂色让你看明白用的,跟你在数据集里给像素打的标签毫无关系。

……


解决方法

回到最初的问题,解决这个Runtime Error的方法就是把数据集重新标注一下,背景标为0,其他类别的像素分别标为1,2,3……就行了。看了一下其他语义分割数据集,没啥类别特别多的,总不至于标到255。现在想想就是很后悔,人家都不标255,就我傻乎乎地图自己看着方便标了0和255,都是血泪……


再总结一下其他几个容易踩雷的地方吧:

1.有的教程里会写要改cityscapes.py或者其他数据集格式文件,但是漏掉了还有一个class_names.py要改,这个文件放在mmseg/core/evaluation里,同样是改一下类别名称和PALETTE。

2.在上面两个文件里,一般不管背景,也就是说,假如你的数据集里有两类物品和背景,那么在classes和PALETTE里只需要写那两类物品的名称和你想给它俩涂的颜色就行。但是在config文件里,num_classes要写为3,因为本质上是做三分类。

3.打标注时记得在灰度图模式下标注。我中途有一次失误就是把标注打成了三通道模式,也就是本来应该标个0就行结果我标成了[0,0,0],于是又一次心酸debug……

4.要老老实实把数据集的像素级标注打成0,1,2,3……尤其是不要闲的没事标255,否则你就会555!!!


最后再来总结一下这次debug的经验心得:

1.遇到问题看官方文档,尤其是Runtime Error这种错误原因千奇百怪的问题,在网上看一百个帖子都不如看官方文档来得有效率。如果我一开始就看到了源码里的那句As the ignore_index often set as 255我还会这么惨嘛……

2.做数据集之前先了解基础的数据集格式,尤其是有些不成文的“潜规则”。很多博文主要是教怎么用工具,所以底层逻辑和框架不一定说得很清楚。或者问问有经验的人也行。

3.遇到问题想不明白钻牛角尖的时候,短暂逃避一下也是挺有用的

你可能感兴趣的:(深度学习,人工智能,python)