CLAM——论文笔记

最近看了一篇有关多示例学习的paper,题目为Data Efficient and Weakly Supervised Computational Pathologyon Whole Slide Images,对里面提出的模型比较感兴趣,特此做一下笔记。
github地址:https://github.com/mahmoodlab/CLAM
paper地址:https://arxiv.org/abs/2004.09666

笔记

这篇paper提出了一个Clustering-constrained Attention Multiple Instance Learning的模型,简称为CLAM模型,整个模型的框架图如下:
CLAM——论文笔记_第1张图片
由于整篇paper对其模型框架介绍的很模糊,而且光看上面的流程图也看不明白。通过对CLAM仓库代码的研究,总算弄明白了整体模型的结构。
还是采用基于embedding的两阶段的训练方式(如果不明白什么是基于embedding,可以看我上一篇文章背景知识部分——Dual-stream multiple instance learning networks for tumor detection in Whole Slide Image——论文笔记)。

(一)特征提取部分

特征提取部分很简单,并没有采用复杂的方式,而是直接采用pytorch提供的预训练权重。仓库代码默认是采用resnet50的ImageNet预训练权重来提取特征的,这部分没什么好说的。

(二)示例分类器

示例分类器的模型结构可以简单概括为门注意力层+全连接层,整体的模型并不复杂,
训练部分可以分为以下步骤,以resnet50提取的特征为例,特征维度为1024:

  1. 使用gate_attention将M×1024的特征向量转为M×1的注意力分数和M×512的特征映射。(M为一个slide中所有的patch数量)
  2. 对注意力分数进行排序,取出最大最小的topk个分数对应的特征映射(topk默认设置为8)。
  3. 将最大topk的标签设为1,最小topk的标签设为0,作为instance标签。
  4. 对2×topk的特征映射输入N个二分类全连接层,得到N个二分类输出。(N为预测类别)
  5. 计算N个二分类输出与instance标签的SmoothTop1SVM Loss(这个loss就是instance loss)。
  6. 将注意力分数乘于对应的特征映射并将所有特征映射相加相加,得到512×1的特征向量。
  7. 将特征向量输进去全连接层,得到bag分类结果。
  8. 计算bag分类的CE Loss。(这个loss就是bag loss)
  9. 总Loss定义为:bag_weight * bag_loss+(1-bag_weight)*instance_loss。(bag_weight默认是0.8)

上面介绍的是整体流程的思路,具体的模型是怎么定义的, 怎么组合的,仓库代码都有,我就不一一赘述了,都是一些很简单模型构件,并不难理解。

(三)一些细节部分

这些细节部分都是从仓库提供的源代码推出来的,并不保证一定是正确的。

  1. dataloader的batch size是固定设为1。
  2. 无论是二分类还是多分类问题,关于计算instance loss时候的标签都是最大topk为1,最小topk为0。
  3. 仓库里面提供了两个模型,一个单分支的clam_sb,另外一个是多分支的clam_mb。 两者的区别在于:
    (1) mb计算注意力分数是计算了每个类别的分数,即N×C个分数,最后分类层也是有多少类别就有多少个全连接层,每个全连接层输出每个类别bag分数。
    (2) sb无论多少类别,计算注意力分数时都是N×1个分数,最后分类层就一个全连接层,输出为类别数。

你可能感兴趣的:(MIL,深度学习,人工智能)