标题:Deep Multi-instance Networks with Sparse Label Assignment for Whole Mammogram Classification
来源:MICCAI 2017
作者:Wentao Zhu, Qi Lou, Yeeleng Scott Vang, and Xiaohui Xie
解决的问题:肿块分类
现有方法 : 多阶段集成。将mass classification分成3个模块:detection, segmentation, classification。
存在的问题:
① 除了每个模块自身的性能问题,这个框架需要单独训练每个模块。多阶段不能完全挖掘出深度神经网络的力量。
② 需要很多ground truth信息 : bbox for detection, segmentation map, mass label。标定这些GT信息,很耗人力。
③ 需要手工特征。
本文方法: end-to-end + multi-instance learning(MIL)
输入:乳腺图像,图像的标签(正常图像 Vs 恶性肿块图像)
输出:对于新的乳腺图像,判断它的标签
(ps. 此处正常图像包括:不含肿块图像&良性肿块图像)
本文方法具体介绍
MIL:假设一个bag上有多个instance。一个instance被分类为positive,bag就被认为是positive;所有instance都被分类为negative,bag才被认为是negative。对应到乳腺图像分析上,将一个image分成多个patch。一个patch被分类为恶性肿块区域,image就被认为是恶性肿块图像;所有patch都被分类为正常乳腺组织,image就被认为是正常图像。于是,本文只考虑image上是恶性肿块概率最大的那个patch,它的良恶性就决定了整幅图像的性质。
方法的整个过程(如图所示)
1) 预处理,将乳腺区域分割出来。
2)用CNN对图像提取特征。此处用的是AlexNet,提取的是它最后一个卷积层,也就是第五个卷积层的特征。由于AlexNet的输入大小限制,这里将图像resize成227*227的,经过Alex-net的5个卷积层后,得到256个6*6的feature map。
3) 对这些feature map做逻辑回归。256个feature map经过逻辑回归之后得到1个6*6的map。这个6*6 map上的每个位置,对应了feature map上相同位置的256个特征值。逻辑回归就是用这256个特征值给每个位置一个评分。评分的大小就是每个位置可能是恶性肿块的概率。此处逻辑回归的公式为:
Fi,j表示第i行第j列的256个特征值,与长度为256的权值向量a做点乘,再加偏置b,然后做sigmoid,得到[0,1]之间的一个数,作为位置(i,j)是恶性肿块的概率。权值向量a和偏置b是所有位置共享的。(ps. 此处有两个疑问:① 什么是[0,1],不应该是(0,1)吗? ② 此处的逻辑回归操作跟1*1卷积有什么不同。个人觉得没什么不同,那为什么不直接表达成1*1卷积呢?)
4)对这6*6=36个概率值r1,r2,...,rm从大到小排序,得到r1',r2',...,rm'
5)
将patch中最大的恶性肿块概率值r1'作为整幅图像是恶性肿块图像的概率,而1-r1'作为整幅图像是正常图像的概率。这个想法和MIL是契合的,我们关注概率最大的那一patch,该patch是positive的,整幅图像就positive;该patch是negative的,由于概率是从大到小排序,整幅图像就negative。
Inference阶段,如果这个最大的概率超过了一定的阈值,那么这幅图像就被认为是恶性肿块图像,否则就是正常图像。为了达到这个目标,损失函数该如何设计呢?
本文给出3种损失函数的计算方法,并给出比较。
损失函数常常是这种形式:ED (ω)是误差项,或是准确率的相反数;ER (ω)是正则化项,防止训练的过拟合。
1)Max Pooling-based Multi-instance Learning
我们只看ED (ω),这里是准确率的相反数。
当图像是恶性肿块图像,它被正确分类的概率就是它被认为是恶性肿块图像的概率,即为r1'。对于正常图像也同理可得。
对于这样的loss function,反复训练之后,将会出现:
对于正常图像,它被正确分类的概率是1-r1',经过训练,r1'越来越小。
因为r1'是最大值,r2',...,rm'也越来越小。于是正常图像上的所有patch被认为是恶性肿块的概率都越来越小,这就相当于给了图像中的每一个patch一个监督信号。这样在训练的过程中,就相当于将图像分成了一个个patch进行学习,大大增加了数据量。
对于恶性肿块图像,它被正确分类的概率是r1',经过训练,r1'越来越大。
因为r1'是最大值,r2',...,rm'并不会受到约束。于是可能存在一种情况,对于只有一个恶性肿块区域的图像,它的r1'=0.99,r2'=0.85,r2'也很大。这样虽然r2'对应的patch不是恶性肿块区域,它也被认为是恶性肿块区域,这就出现了监督信号的错误。
2) Label Assignment-based Multi-instance Learning
第一种方法的问题是只考虑到最大概率的那个patch。对于正常图像,其它的patch虽然通过概率的大小排序有所约束;但是恶性肿块图像的其它patch都没被考虑到,导致了错误的监督信号的产生。于是第二种方法想去考虑每一个patch,它不是对整幅图像赋予标签,而是给每个patch赋予标签。
对于正常图像,自然所有patch都是正常的,均赋予标签0;对于恶性肿块图像,本文认为其中有k个patch都是恶性肿块,而剩余的patch都是正常的。于是概率最大的k个patch的标签为1,其余patch标签为0。
loss function中是对一张图像上的所有patch的分类准确率进行加和。Q1',...Qm'是r1',…,rm'对应的patch。这样的话,每个patch都得到了考虑。但这里的问题是,这个k要怎么确定呢?要怎么知道每幅图像有多少个patch是肿块呢?答案是没办法确定,因为ground truth就只有整幅图像的标签。
在实验过程中,作者的做法是为k赋了一个常数值,所有图像的k都是一样的。这样肯定是不太好的,也有一些patch的监督信号是错的。但相比于第一种方法中,恶性肿块图像中大部分的patch得到了一个错误的监督信号,这里的情况已经有所改善了。
3) Sparse Multi-instance Learning
本文提出了第三种方法,它综合了前两种方法的思路,也解决了它们各自的问题。相比方法一,方法三综合考虑了所有patch的信息,相比于方法二,方法三灵活地指出patch的个数。
可以发现,这个loss function跟第一种方法的差距就在 。与第一种方法相同的是,三用 考虑了最大概率的那个patch。那么它如何去考虑其他的patch呢?注意到一个先验知识,一幅图像中肿块是很少的,1块,2块,3块就算很多了,剩下的都是正常的乳腺组织。因此,在给每个patch赋予标签时,恶性肿块为1,正常为0,那么1的个数特别少,大部分都是0,因此这是一个稀疏问题。
于是加了 这一项,这是一个稀疏项。加了这一项之后,经过loss function去反复训练网络,那么最终得到的r就非常稀疏,r1',r2',...,rm'中只有几个有值,剩下的都是0。于是对于正常图像,它用 使得所有patch的概率值变小,并通过稀疏,使得大部分的概率接近于0,只有少数patch有较小的概率值,得到的都是正确的监督信号;对于恶性肿块图像,它用 使得最大概率值变大,同时通过稀疏,使得大部分patch的概率接近于0,只有少数patch有较大的概率值,这些较大的概率值就对应着恶性肿块区域,实现了数量的估计。