完整代码见Github :https://github.com/capsule2077/CAM-Visualization ,如果有用可以点个Star,谢谢!
CAM是类别激活映射图,提出于论文:Learning Deep Features for Discriminative Localization ,这是一篇发表于2015年的CVPR论文。论文主要强调了卷积神经网络具有的特征定位能力,但是堆叠的全连接层会破环这种能力。为了证明这种能力,作者通过CAM技术把这种能力可视化出来,下面的图很直观,这是ResNet18
对最左上角的原图进行预测的Top5
的结果,图的上方注明了具体的类别和对应的概率值,预测结果的图以类似热力图的样式进行标注,颜色越深的地方就是模型越关注的特征,针对不同的分类情况模型的关注点都不同。
针对概率值最大的是dome(圆顶)
,可以看到颜色最深的地方汇聚于建筑物的圆顶部分。而针对palace(宫殿)
,模型的关注点更加宽泛,深颜色区域遍布整个建筑物。由此可见,CNN确实可以捕捉到圆顶的特征和位置信息。
CAM技术其实就是全局平均池化层和一个全连接层来实现;全局平均池化提出于论文:Network In Network,与最大池化相比,全局平均池化层可以保留更多的图像信息,因为它计算每个池化区域的平均值,有助于捕捉整体图像的分布信息。
全局平均池化层一般应用于网络的末端。假设有一个三通道的特征图,将红色特征图所有值算平均得到一个神经元输出,蓝绿特征图类似。有多少个通道就能输出一个相同维度的向量,而这个向量经过softmax就能得到分类的概率值。所以如果要用于ImageNet进行分类,在经过全局平均池化层前的特征图是有1000个通道的,对应1000个类别。
由于GAP(Global Average Pooling 全局平均池化层)考虑一整个通道的这种特性,那么我在GAP的后面加上一个全连接层,相当于对每个通道(因为GAP得到的向量是由每个通道全局平均得来的)进行加权,其他和普通网络一样。
举个例子,如果神经网络将图片分类成”狗“,那么把”狗“对应的全连接层权重与GAP前的特征图进行加权计算,权重值代表了某个特征对图像分类为”狗“的重要性。例如Resnet18最后一层特征图有512个通道。这512个通道可以认为提取到不同的特征,该特征具有高度抽象性,且每个通道对最后的结果贡献不同,因此单独可视化每个通道获取热图也让人很难理解。所以CAM技术根据每个通道不同的贡献大小去融合获取一张CAM,类似于热力图可以很清晰的看出网络聚焦的特征。
由于CAM技术要求网络的最后由GAP + FC组成,但是AlexNet、VGG这些网络都不符合,所以需要自行修改网络结构。如何修改网络结构参考:【Pytorch】加载预训练模型及修改网络结构
这里以修改VGG16为例:
def get_vgg(depth = 16):
# 调用torchvision中的VGG16
vgg = getattr(torchvision.models, "".join(['vgg', str(depth)]))()
# 修改维全局平均池化层,即一个通道的特征图最后只输出一个值
vgg.avgpool.output_size = (1, 1)
# 修改全连接层
vgg.classifier = nn.Linear(512, 100)
return vgg
ResNet网络是符合GAP + FC的结构的,所以如果用ResNet进行试验的话可以直接调用预训练模型。
由于修改了网络结构,所以需要对参数进行微调。微调可以只微调全连接层,其余层加载官方的预训练权重。训练时只需要将全连接层的参数传入SGD优化器即可model.classifier.parameters()
,我使用预训练的权重在MINI-ImageNet数据集上做100分类的微调 ,MINI-ImageNet参考之前的记录:【代码实验】CNN实验——利用Imagenet子集训练分类网络(AlexNet/ResNet)
微调后的权重文件:VGG16_100.pth,提取码:3v0i
...
# 训练时只需要将全连接层的参数传入SGD优化器即可
optimizer = optim.SGD(model.classifier.parameters(), lr=args.learning_rate,
momentum=0.9, weight_decay=0.0005)
...
由于需要获取GAP层前的特征图,所以需要获取模型的中间输出,参考之前的记录:【Pytorch】六行代码实现:特征图提取与特征图可视化
以上三点是前置知识,可以更好地理解代码
这里没有对CAM技术的理论进行详细解释,更推荐看原论文。如果只想应用的话,完整代码见Github :https://github.com/capsule2077/CAM-Visualization,代码中有注释,欢迎提出问题,如果有用可以点个Star,谢谢!
万字长文:特征可视化技术(CAM)
神经网络可视化——CAM及其变体