参考博客:
图像分类:GoggLeNet网络、五分类 flower 数据集、pytorch
ResNet 网络结构与残差连接介绍
GoggLeNet中的 inception 结构通过并联多个具有卷积核大小的卷积层,增加了网络层的宽度。
ResNet中的 Residual 结构(残差块)通过纵向的串联,增加了网络的深度。
这两种结构作为神经网络中最早出现的增加网络宽度和深度的方法,大大提高了网络的性能。尤其是ResNet的残差连接成为了非常著名的经典网络。
在这两种结构中,不论是并联和串联之后,都需要做拼接或相加的操作,那么二者有什么不同呢?先说结论:
(1)GoggLeNet中的 inception 结构的四个分支得到的tensor只有 channel 不同,CHW都是相同的。因此最终的 tensor是四个分支的tensor在 channel 通道进行相加。
(2) ResNet中的残差块的输入和输出的shape完全相同,即经过残差块之后tensor的NCHW都不变,可以直接与输入的tensor相加,相加之后的shape不变。
GoggLeNet的网络结构采用 图像分类:GoggLeNet网络、五分类 flower 数据集、pytorch 的model.py模块。
inception结构如下:
inception结构有四个分支,四个输出要进行DepthConcat,即深度拼接(也就是channel方向的拼接)。在 model.py的 class Inception() 中,我们打印四个分支(branch)和经过 DepthConcat 之后的tensor的shape如下:
branch1.shape = torch.Size([32, 64, 28, 28])
branch2.shape = torch.Size([32, 128, 28, 28])
branch3.shape = torch.Size([32, 32, 28, 28])
branch4.shape = torch.Size([32, 32, 28, 28])
result.shape = torch.Size([32, 256, 28, 28])
可以看到,result_tensor 的通道数是四个 sub_tensor 通道数的直接相加。
result_tensor 是四个sub_tensor 通过 torch.cat() 拼接的,那么 torch.cat() 的作用是什么呢?看一个例子就明白了:
import torch
a = torch.randn(1, 4)
b = torch.randn(2, 4)
c = torch.cat([a, b]) # torch.Size([3, 4])
print(f"a: {a}")
print(f"b: {b}")
print(f"c: {c}")
# output:
a: tensor([[-0.1987, -1.4841, -0.7369, -1.3955]])
b: tensor([[-1.0136, 0.4432, 0.8319, 1.3025],
[ 0.3873, -0.4107, -0.0439, 0.6455]])
c: tensor([[-0.1987, -1.4841, -0.7369, -1.3955],
[-1.0136, 0.4432, 0.8319, 1.3025],
[ 0.3873, -0.4107, -0.0439, 0.6455]])
可以看到 torch.cat() 做的就是矩阵拼接,将两个矩阵拼接成一个更大的矩阵,并没有做任何相加的运算。
残差连接是直接相加,要求 F(x)与x 的shape完全相同,F(x)+x 是对应元素相加,看例子:
fx = torch.tensor([1, 2])
x = torch.tensor([3, 4])
print(fx + x) # tensor([4, 6])
用pytorch实现一个最基础的残差块:
import torch
import torch.nn as nn
class ResnetBlock(nn.Module):
def __init__(self, convChannel):
super(ResnetBlock, self).__init__()
self.conv = nn.Conv2d(in_channels=convChannel, out_channels=convChannel, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn = nn.BatchNorm2d(convChannel)
# BN层是对一个channel内的元素做归一化,BN层要放在卷积层和激活层之间
self.relu = nn.ReLU()
def forward(self, x):
x_input = x
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x + x_input) # 残差连接,先相加再做relu非线性激活
return x
if __name__ == '__main__':
batch_size = 4
channel = 3
pic_size = 100
data = torch.randn(batch_size, channel, pic_size, pic_size) # NCHW
model = ResnetBlock(convChannel=channel)
output = model(data)
print(data.shape) # torch.Size([4, 3, 100, 100])
print(output.shape) # torch.Size([4, 3, 100, 100])
可以看到,残差块的输入和输出的tensor的shape相同。