【GoggLeNet 的 inception 结构】 与 【ResNet 的 Residual 结构】 tensor拼接方式的区别

文章目录

      • 一、两种结构的介绍
      • 二、inception 结构的tensor拼接方法
      • 三、Residual 结构的tensor相加

参考博客:

图像分类:GoggLeNet网络、五分类 flower 数据集、pytorch

ResNet 网络结构与残差连接介绍

一、两种结构的介绍

GoggLeNet中的 inception 结构通过并联多个具有卷积核大小的卷积层,增加了网络层的宽度。

ResNet中的 Residual 结构(残差块)通过纵向的串联,增加了网络的深度。

这两种结构作为神经网络中最早出现的增加网络宽度和深度的方法,大大提高了网络的性能。尤其是ResNet的残差连接成为了非常著名的经典网络。

在这两种结构中,不论是并联和串联之后,都需要做拼接或相加的操作,那么二者有什么不同呢?先说结论:

(1)GoggLeNet中的 inception 结构的四个分支得到的tensor只有 channel 不同,CHW都是相同的。因此最终的 tensor是四个分支的tensor在 channel 通道进行相加。

(2) ResNet中的残差块的输入和输出的shape完全相同,即经过残差块之后tensor的NCHW都不变,可以直接与输入的tensor相加,相加之后的shape不变。

二、inception 结构的tensor拼接方法

GoggLeNet的网络结构采用 图像分类:GoggLeNet网络、五分类 flower 数据集、pytorch 的model.py模块。

inception结构如下:

【GoggLeNet 的 inception 结构】 与 【ResNet 的 Residual 结构】 tensor拼接方式的区别_第1张图片

inception结构有四个分支,四个输出要进行DepthConcat,即深度拼接(也就是channel方向的拼接)。在 model.py的 class Inception() 中,我们打印四个分支(branch)和经过 DepthConcat 之后的tensor的shape如下:

branch1.shape = torch.Size([32, 64, 28, 28])
branch2.shape = torch.Size([32, 128, 28, 28])
branch3.shape = torch.Size([32, 32, 28, 28])
branch4.shape = torch.Size([32, 32, 28, 28])
result.shape = torch.Size([32, 256, 28, 28])

可以看到,result_tensor 的通道数是四个 sub_tensor 通道数的直接相加。

result_tensor 是四个sub_tensor 通过 torch.cat() 拼接的,那么 torch.cat() 的作用是什么呢?看一个例子就明白了:

import torch
a = torch.randn(1, 4)
b = torch.randn(2, 4)
c = torch.cat([a, b])  # torch.Size([3, 4])
print(f"a: {a}")
print(f"b: {b}")
print(f"c: {c}")

# output:
a: tensor([[-0.1987, -1.4841, -0.7369, -1.3955]])
b: tensor([[-1.0136,  0.4432,  0.8319,  1.3025],
        [ 0.3873, -0.4107, -0.0439,  0.6455]])
c: tensor([[-0.1987, -1.4841, -0.7369, -1.3955],
        [-1.0136,  0.4432,  0.8319,  1.3025],
        [ 0.3873, -0.4107, -0.0439,  0.6455]])

可以看到 torch.cat() 做的就是矩阵拼接,将两个矩阵拼接成一个更大的矩阵,并没有做任何相加的运算。

三、Residual 结构的tensor相加

残差连接是直接相加,要求 F(x)与x 的shape完全相同,F(x)+x 是对应元素相加,看例子:

fx = torch.tensor([1, 2])
x = torch.tensor([3, 4])
print(fx + x)    # tensor([4, 6])

用pytorch实现一个最基础的残差块:

import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    def __init__(self, convChannel):
        super(ResnetBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=convChannel, out_channels=convChannel, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(convChannel)    
        # BN层是对一个channel内的元素做归一化,BN层要放在卷积层和激活层之间
        self.relu = nn.ReLU()

    def forward(self, x):
        x_input = x
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x + x_input)    # 残差连接,先相加再做relu非线性激活
        return x


if __name__ == '__main__':
    batch_size = 4
    channel = 3
    pic_size = 100
    data = torch.randn(batch_size, channel, pic_size, pic_size)   # NCHW
    model = ResnetBlock(convChannel=channel)
    output = model(data)
    
    print(data.shape)       # torch.Size([4, 3, 100, 100])
    print(output.shape)     # torch.Size([4, 3, 100, 100])

可以看到,残差块的输入和输出的tensor的shape相同。

你可能感兴趣的:(图像分类,深度学习,机器学习,python)