CLAM:基于全幻灯片图像的数据高效和弱监督的计算病理学

CLAM:基于全幻灯片图像的数据高效和弱监督的计算病理学

前提知识:

MIL多示例学习:

CLAM处理的是多分类问题,对于输入的WSI图,根据它的形态学特征可以分为N个类别,一张WSI图只能属于其中的一个类。(比如WSI图共有三个类别:a b c,一张WSI只属于a b c中的一类)。若一张WSI图属于某一类,则这个WSI图的所划分出来的所有patch都属于这一类。

CLAM模型的流程解释:

a:

  1. 对原始WSI图像进行语义分割,去掉背景
  2. 对取出来的这个图像划分成M个Patch(每个patch256*256pixels)

b:

  1. 对于每个patch,逐个输入到ResNet50中,进行特征提取,降维,输出一个xx维度的向量z_k
  2. 对于(1)输出的向量,使其输入一个Attention backbone,输出对这个向量的attention评分
  3. 因为共有M个patch,所以(2)输出有M个向量,attention backbone还将这M个向量分成N个类,进入了N条分支。这N个类中有一类是原始WSI所属的类。
  4. 对于每个分好了类别的patch,最终将经过一系列“操作”,输出一个向量s,这个向量s有 N维,代表了每个patch所代表的类跟真实标签之间的关系轻重度。最终这个向量s再经过一个“操作”,根据已知s来预测这个WSI图的标签。最终这个预测的WSI标签和ground truth标签通过CE Loss来迭代训练,记为L_slide。
  5. 对于(2)输出的带评分的patch,将其全部按照得分由低到高排序,取最低的K个并给一个伪标签为0,取最高的K个并给一个伪标签为1,对这2K个patch做聚类,使用Smooth SVM loss来迭代训练,记为L_patch。
  6. 最终要迭代训练优化的损失函数为L_total=c1*L_slide+c2*L_patch

c:

  1. 对于(2)输出的带评分的patch,将其全部按照得分由低到高排序,取最低的K个并给一个伪标签为0,取最高的K个并给一个伪标签为1,对这2K个patch做聚类,使用Smooth SVM loss来迭代训练,记为L_patch。
  2. 通过attention pooling转换为热图

最终要迭代训练优化的损失函数为L_total=c1*L_slide+c2*L_patch

最终自己根据代码的理解 画出来的完整CLAM模型的流程图。不一定对,或者说大概率是错的。

你可能感兴趣的:(医学图像处理,人工智能)