Non-local_pytorch代码解读

参考https://zhuanlan.zhihu.com/p/33345791以及https://github.com/AlexHex7/Non-local_pytorch

图片.png

代码总的模型框架

from torch import nn
# from lib.non_local_concatenation import NONLocalBlock2D
# from lib.non_local_gaussian import NONLocalBlock2D
from lib.non_local_embedded_gaussian import NONLocalBlock2D
# from lib.non_local_dot_product import NONLocalBlock2D


class Network(nn.Module):
    """Small CNN classifier with a non-local block before the 2nd and 3rd conv stage.

    Three conv stages (1 -> 32 -> 64 -> 128 channels), each ending in a 2x2
    max-pool, feed a two-layer fully connected head with 10 outputs. The head
    expects 128*3*3 flattened features, i.e. a 3x3 map after the last pooling.
    """

    def __init__(self):
        super(Network, self).__init__()

        def conv_stage(c_in, c_out):
            # 3x3 conv (padding 1, stride 1) -> batch norm -> ReLU -> 2x2 max-pool.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        # Layers are created in exactly the order they run, so the indices
        # inside the Sequential container stay 0..13.
        layers = conv_stage(1, 32)
        layers.append(NONLocalBlock2D(in_channels=32))
        layers.extend(conv_stage(32, 64))
        layers.append(NONLocalBlock2D(in_channels=64))
        layers.extend(conv_stage(64, 128))
        self.convs = nn.Sequential(*layers)

        # Classifier head: 128*3*3 flattened features -> 256 -> 10 outputs.
        self.fc = nn.Sequential(
            nn.Linear(128 * 3 * 3, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the FC head."""
        features = self.convs(x)
        flattened = features.view(x.size(0), -1)
        return self.fc(flattened)

if __name__ == '__main__':
    import torch

    # Smoke test: push a random batch of three 1x28x28 inputs through the
    # network and print the size of what comes out.
    sample_batch = torch.randn(3, 1, 28, 28)
    model = Network()
    print(model(sample_batch).size())

model主框架是对MNIST数据集做分类,只不过在中间加入了Non-local模块,下面进入Non-local模块的代码学习
先从上面图里的框架看起吧
non_local_embedded_gaussian

def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        '''
        Create the layers of a 1D/2D/3D non-local block.

        :param in_channels: channels of the input; ``W`` projects back to this many
        :param inter_channels: channels produced by ``g``/``theta``/``phi``;
            ``None`` means ``in_channels // 2``, or 1 if that would be 0
        :param dimension: 1, 2 or 3 -- picks the Conv/BatchNorm/MaxPool Nd family
        :param sub_sample: if true, a max-pool layer is appended to ``g`` and ``phi``
        :param bn_layer: if true, ``W`` is a 1x1 conv followed by batch norm
        '''
        super(_NonLocalBlockND, self).__init__()

        assert dimension in (1, 2, 3)

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels

        # Default embedding width: half of the input channels, never zero.
        if inter_channels is None:
            inter_channels = (in_channels // 2) or 1
        self.inter_channels = inter_channels

        # Layer family per dimensionality; anything that is not 3 or 2 uses
        # the 1D family (the assertion above leaves only 1 in that case).
        families = {
            3: (nn.Conv3d, nn.BatchNorm3d, nn.MaxPool3d, (1, 2, 2)),
            2: (nn.Conv2d, nn.BatchNorm2d, nn.MaxPool2d, (2, 2)),
        }
        conv_nd, bn, pool_cls, pool_kernel = families.get(
            dimension, (nn.Conv1d, nn.BatchNorm1d, nn.MaxPool1d, 2))
        max_pool_layer = pool_cls(kernel_size=pool_kernel)

        def pointwise(c_in, c_out):
            # Kernel size 1, stride 1, no padding: only the channel count changes.
            return conv_nd(in_channels=c_in, out_channels=c_out,
                           kernel_size=1, stride=1, padding=0)

        self.g = pointwise(self.in_channels, self.inter_channels)

        # Output projection back to in_channels. The weight and bias of its
        # last layer (batch norm if present, otherwise the conv) are set to 0,
        # so W produces all zeros right after construction.
        if bn_layer:
            self.W = nn.Sequential(
                pointwise(self.inter_channels, self.in_channels),
                bn(self.in_channels),
            )
            zeroed = self.W[1]
        else:
            self.W = pointwise(self.inter_channels, self.in_channels)
            zeroed = self.W
        nn.init.constant_(zeroed.weight, 0)
        nn.init.constant_(zeroed.bias, 0)

        self.theta = pointwise(self.in_channels, self.inter_channels)
        self.phi = pointwise(self.in_channels, self.inter_channels)

        # Optional sub-sampling: g and phi share one max-pool layer instance.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

init函数除了往常的初始化之外,主要定义了g、theta、phi三个1*1卷积和输出层W(1*1卷积,可选再接bn),并把W最后一层(bn或卷积)的weight和bias初始化为0;sub_sample参数决定是否在g和phi后面加入max_pooling
然后下面是主要的代码

def forward(self, x):
        '''
        Embedded-Gaussian non-local operation with a residual connection.

        :param x: (b, c, t, h, w) for the 3D block; in general (b, c, *spatial)
        :return: ``self.W(y) + x``, where ``y`` has the spatial size of ``x``
        '''

        batch_size = x.size(0)

        # g(x): (b, c', N_g) -> (b, N_g, c'); one row of values per position.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta(x): (b, N, c'); phi(x): (b, c', N_phi).
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Pairwise similarities (b, N, N_phi), softmax-normalised over the last axis.
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Weighted sum of values, then back to (b, c', *spatial of x).
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

1.把hwt(3维),hw(2维) 放到一起,归为一个维度W
2.g_x=BWC,theta_x=BWC,phi_x=BCW,f=BWW,f_div_C=BWW
3.y=BWC-> BCW-> BChw
4.W_y 也是1*1卷积,bn可选择加,最后一个残差连接
从代码的角度来看,就是先用三次1*1卷积,然后把其中两次的结果相乘,再用softmax做归一化,形成一个W*W大小的权重,然后用这个权重与第三个相乘,有点类似attention的操作,也就是加了一层系数加权。其实这个操作也有点类似全连接,计算量同样很大,只是多了一层相似性。
然后再来看non_local_concatenation的写法,主要罗列区别性

  # Excerpt, presumably from __init__ of the concatenation variant (confirm
  # against the full file): a bias-free 1x1 Conv2d maps the 2 * inter_channels
  # concatenated channels to a single channel, followed by ReLU.
  self.concat_project = nn.Sequential(
      nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
      nn.ReLU()
  )

  def forward(self, x):
      '''
      Concatenation-based non-local operation with a residual connection.

      :param x: (b, c, t, h, w) for the 3D block; in general (b, c, *spatial)
      :return: ``self.W(y) + x``
      '''

      batch_size = x.size(0)

      # g(x): (b, c', N_g) -> (b, N_g, c'); one row of values per position.
      g_x = self.g(x).view(batch_size, self.inter_channels, -1)
      g_x = g_x.permute(0, 2, 1)

      # (b, c, N, 1)
      theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
      # (b, c, 1, N)
      phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

      # Tile both embeddings to a common (b, c, h, w) grid so that entry
      # (i, j) pairs theta at position i with phi at position j.
      h = theta_x.size(2)
      w = phi_x.size(3)
      theta_x = theta_x.repeat(1, 1, 1, w)
      phi_x = phi_x.repeat(1, 1, h, 1)

      # Stack along channels and project each pair to a single score.
      concat_feature = torch.cat([theta_x, phi_x], dim=1)
      f = self.concat_project(concat_feature)
      b, _, h, w = f.size()
      f = f.view(b, h, w)

      # Normalise by the size of the last axis of f (no softmax here).
      N = f.size(-1)
      f_div_C = f / N

      # Weighted sum of values, then back to (b, c', *spatial of x).
      y = torch.matmul(f_div_C, g_x)
      y = y.permute(0, 2, 1).contiguous()
      y = y.view(batch_size, self.inter_channels, *x.size()[2:])
      W_y = self.W(y)
      z = W_y + x
      return z

主要不同点在于形成权重的方式上:这里采用在通道维度上拼接(concat)的方式。由于theta_x和phi_x的形状不同,所以先用repeat扩成相同的形状,然后再拼接,拼接后用1*1卷积降到1个通道。f_div_C = f / N,感觉没什么用。。。可能数值太大,压缩一下吧,毕竟W很大。。。(这里N = f.size(-1),也就是位置的个数,除以N相当于代替softmax做归一化。)
然后来看non_local_dot_product

 def forward(self, x):
        '''
        Dot-product non-local operation with a residual connection.

        :param x: (b, c, t, h, w) for the 3D block; in general (b, c, *spatial)
        :return: ``self.W(y) + x``
        '''
        n_batch = x.shape[0]
        c_inter = self.inter_channels

        # Values, one row per position: (b, N_g, c').
        values = self.g(x).view(n_batch, c_inter, -1).transpose(1, 2)

        # Queries (b, N, c') and keys (b, c', N_phi).
        queries = self.theta(x).view(n_batch, c_inter, -1).transpose(1, 2)
        keys = self.phi(x).view(n_batch, c_inter, -1)

        # Raw pairwise dot products divided by the size of their last axis;
        # this variant applies no softmax.
        scores = torch.matmul(queries, keys)
        weights = scores / scores.shape[-1]

        # Weighted sum of values, reshaped back to (b, c', *spatial of x).
        aggregated = torch.matmul(weights, values).transpose(1, 2).contiguous()
        aggregated = aggregated.view(n_batch, c_inter, *x.shape[2:])

        return self.W(aggregated) + x

相比embedded_gaussian,直接采用了点乘,形成W*W的权重,没有使用softmax
最后看下non_local_gaussian

 # Excerpt, presumably from __init__ of the Gaussian variant (confirm against
 # the full file): phi is bound directly to the max-pool layer, with no conv.
 self.phi = max_pool_layer
 def forward(self, x):
        '''
        Gaussian non-local operation with a residual connection.

        Similarities are computed on ``x`` itself: ``theta`` is not used, and
        ``phi`` is applied only when ``self.sub_sample`` is true.

        :param x: (b, c, t, h, w) for the 3D block; in general (b, c, *spatial)
        :return: ``self.W(y) + x``
        '''

        batch_size = x.size(0)

        # g(x): (b, c', N_g) -> (b, N_g, c'); one row of values per position.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        g_x = g_x.permute(0, 2, 1)

        # Queries are the raw input features: (b, N, c).
        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # Keys: (b, c, N_phi), passed through self.phi only when sub-sampling.
        if self.sub_sample:
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        # Pairwise similarities (b, N, N_phi), softmax-normalised over the last axis.
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # Weighted sum of values, then back to (b, c', *spatial of x).
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

相比embedded_gaussian,theta直接使用x本身(不再经过1*1卷积),phi的1*1卷积换成了max_pooling层(只在sub_sample为True时使用),但这两种写法的区别,我还不是很清楚,等看了论文再来补充。

你可能感兴趣的:(Non-local_pytorch代码解读)