细粒度分类网络之WS-DAN论文阅读附代码

论文阅读

细粒度分类 (FGVC) 是为了解决“类内分类”问题,有别于猫狗分类,它要解决的是 [这只狗是萨摩还是哈士奇] 这种问题。这类问题的特点是类别之间的区别较小,本人从事的瑕疵检测也是属于这一领域,有瑕疵的样本与正常样本往往区别很小,用普通的分类网络并不能达到很好的效果,这篇论文中介绍的网络亲测比普通的分类网络效果更好。

该论文提出了一种针对细粒度视觉分类任务的方法,采用基于弱监督学习的图像增强方法,结合注意力机制,这使得网络可以在不需要额外标注信息的情况下聚焦到那些图像中“有话语权”的部分,在细粒度分类问题中达到 state-of-art 的水准。

论文地址:https://arxiv.org/pdf/1901.09891.pdf​

论文的前面两小节,作者大概介绍了一下他们发这篇论文做的工作以及业界对 FGVC 问题的进展。精华从第三小节开始:

1.训练过程

细粒度分类网络之WS-DAN论文阅读附代码_第1张图片

上图是整个网络的训练过程,也是整片论文的核心。训练过程分成了**(A)Weakly Supervised Attention Learning** 以及 (B) Attention-Guided Data Augmentation 两部分,下面分别讲解这两部分

  • (A)Weakly Supervised Attention Learning

这一步是基于弱监督的注意力区域学习。首先,网络会对原始图片基础 CNN 进行特征提取,特征提取网络默认使用 inceptionV3,当然我们也可以用其他网络。提取到的特征文中称为 Feature maps,随后 Feature maps 经过一个kernel size 为 1 的卷积运算得到 Attention maps,就是说 Attention maps 是由 Feature maps 降维之后得到的,具体降到多少维度 M 是一个超参数可以自行配置。根据作者描述,M 个 Attention map 中每一个都代表了物体的一个位置例如鸟的头部,飞机的机翼等。后面网络还会根据 Attention map 对图片进行针对性的增强。

细粒度分类网络之WS-DAN论文阅读附代码_第2张图片

在得到 Feature maps 和 Attention maps 之后,作者受 Bilinear Pooling 的启发,提出了 Bilinear Attention Pooling,简称 BAP,如上图中所示,具体操作是将 Feature maps 与每个 channel 的 Attention map 按元素相乘,如下式。相乘之后再经过池化降维以及拼接操作获得最后的 Feature Matrix,这是最后线性分类层的输入。

F k = A k ⊙ F ( k = 1 , 2 , . . . , M ) F_k = A_k \odot F(k = 1, 2, ..., M) Fk=AkF(k=1,2,...,M)

  • (B) Attention-Guided Data Augmentation

这一步是用之前步骤获得的 Attention map 来指导数据增强,这会比普通的随机数据增强更有优势,将Attention map 提取的部位放大作为增强后的数据进行训练,为细粒度分类这一问题提供了有效的解决方式。

细粒度分类网络之WS-DAN论文阅读附代码_第3张图片

在上面的步骤中,我们获得了 M 个 Attention Map,网络会在 M 个中随机选取一个作为后面做数据增强的依据,至于为啥随机选取我的理解是第一可以增加鲁棒性,第二是可以对多个物体部位做到“雨露均沾”。随机选取一个 Attention Map 之后先对其归一化以方便后续的操作。

细粒度分类网络之WS-DAN论文阅读附代码_第4张图片

现在可以根据 Attention Map 生成 Crop Mask 了,Crop Mask 个人理解为截图的策略,文中策略是将 A k ∗ A_k^* Ak 中大于阈值 θ c \theta _c θc 的元素置为 1 ,其他置为 0,这一块为 1 的区域就是我们细粒度分类中需要的细节区域,将它上采样至模型输入的图片大小,当作一个新的“样本”输入对模型进行训练,以强制模型“注意”这些细节区域。上面的 θ c \theta _c θc 作为一个超参数也是可以根据具体问题进行调节的,文中默认为 0.5。

Attention Dropping 与 Attention Cropping 类似,将 Attention Map 中小于阈值 θ d \theta_d θd 的元素置为 1 ,其他为 0 。加入这个操作是因为作者发现不同的 Attention Maps 可能聚焦了物体相同的部位,为了让模型也可以注意到其他位置,比如上图中的 Attention Map 是鸟的头部,该操作就可以让模型注意到鸟的其他部位,就像是在告诉模型,看看啊,除了头,身体长这样的也是某某种鸟啊。Attention Dropping 操作让模型提高了 0.6% 的准确率。

训练过程还有一个很新颖的点是损失函数的设计,除了计算预测结果的交叉熵损失之外,作者为了每次各个 Attention Map 可以找到相同的物体部位,还加入了特征图与部位中心的平方差之和作为惩罚项,如下公式,这就会让每个特征图固定到每个部位的中心。其中部位中心也是每次学习到的特征图来更新的,这种设计真的很妙!

细粒度分类网络之WS-DAN论文阅读附代码_第5张图片

2. 预测过程

细粒度分类网络之WS-DAN论文阅读附代码_第6张图片

预测过程依然分为两个部分,最终预测结果是两个子结果的平均值。

  1. 第一步,原始图片输入训练好的模型中得到属于各个类别的概率,以及 Attention Maps

  2. 第二步,将第一步中得到的 M 个 Attention Maps 取平均值,注意这里不是像训练过程里面随机取一个区域,我的理解是这里如果随机取的话,可能会导致模型不稳定,每次的预测结果不一样。下面就是与训练过程类似了,根据 Attention Maps 的平均值 A m A_m Am 画出截取框,将截取框上采样再放入训练好的网络中,得到“注意力区域”属于各个类别的概率。

  3. 最后一步将上面两步的结果取平均值得到最后的分类结果

后面作者对比了一下该算法与现有的算法在各个数据集上的表现:

细粒度分类网络之WS-DAN论文阅读附代码_第7张图片

细粒度分类网络之WS-DAN论文阅读附代码_第8张图片

细粒度分类网络之WS-DAN论文阅读附代码_第9张图片

可以看出,该算法在各个细粒度分类数据集上的表现都比现有的分类算法有所提升,大家如果有细粒度分类的任务也可以试试该算法。

代码

pytorch 实现的代码地址:https://github.com/GuYuc/WS-DAN.PyTorch

这个代码一个大神写的,有注释很容易看懂,跑起来也ok,如果有问题的话可以留言我们一起交流。

待续

后面想尝试将其中的特征提取网络换成 se_ResNeXt ,这也是在细粒度分类领域中常用的算法,如果结合的话也许会有更好的效果,试试才知道。

PS:欢迎关注我的个人微信公众号 [MachineLearning学习之路],每周一篇 CV 方向的论文解读奉上!

你可能感兴趣的:(细粒度分类网络之WS-DAN论文阅读附代码)