CNN卷积神经网络之SKNet及代码

CNN卷积神经网络之SKNet及代码

  • 前言
  • SK Convolution细节
  • 网络结构
  • 实验结果
  • 代码

《Selective Kernel Networks》
论文地址:https://arxiv.org/pdf/1903.06586.pdf

前言

CVPR2019 SKNet是SENet的加强版,是attention机制中的与SE同等地位的一个模块,可以方便地添加到现有的网络模型中,对分类问题,分割问题有一定的提升。如果不清楚SENet的,可以先看一下CNN卷积神经网络之SENet

SK Convolution细节

CNN卷积神经网络之SKNet及代码_第1张图片

可以动态的选择融合不同尺度卷积核的特征图。当然这里可以不止两个分支,还可以很多个分支。

网络结构

CNN卷积神经网络之SKNet及代码_第2张图片
M代表多少路分支,G代表分组卷积,r负责控制上图中Z的维度d大小,L=32保证最小为32维:

在这里插入图片描述

实验结果

图像分类结果:
CNN卷积神经网络之SKNet及代码_第3张图片

不同卷积核组合:
CNN卷积神经网络之SKNet及代码_第4张图片

此外,作者还发现,在大多数通道中,当目标增大时,5x5卷积核所占的权重也增加,和之前的猜想是一致的。另一个的发现是,此现象只存在中浅层中。

代码

import torch.nn as nn
import torch

class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        for i in range(M):
            # 使用不同kernel size的卷积
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(features,
                              features,
                              kernel_size=3 + i * 2,
                              stride=stride,
                              padding=1 + i,
                              groups=G), nn.BatchNorm2d(features),
                    nn.ReLU(inplace=False)))
            
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        fea_U = torch.sum(feas, dim=1)
        fea_s = fea_U.mean(-1).mean(-1)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            print(i, fea_z.shape)
            vector = fc(fea_z).unsqueeze_(dim=1)
            print(i, vector.shape)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector],
                                              dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

if __name__ == "__main__":
    t = torch.ones((32, 256, 24,24))
    sk = SKConv(256,WH=1,M=2,G=1,r=2)
    out = sk(t)
    print(out.shape)

上一篇:CNN卷积神经网络之SENet及代码

你可能感兴趣的:(CNN卷积神经网络,cnn,深度学习,机器学习)