残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。如图是一个正常块(左)和一个残差块(右)的结构,可以看出在残差块中,输入可通过跨层数据线路更快地向前传播。
ResNet沿用了VGG完整的3×3卷积层设计,残差块里首先有2个有相同输出通道数的3×3卷积层。 每个卷积层后接一个批量规范化层和ReLU激活函数。然后通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。 如果想改变通道数,就需要引入一个额外的1×1卷积层来将输入变换成需要的形状后再做相加运算。如下图是包含以及不包含 1×1 卷积层的残差块结构:
代码实现如下:
!pip install git+https://github.com/d2l-ai/d2l-zh@release # installing d2l
!pip install matplotlib_inline
!pip install matplotlib==3.0.0
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
class Residual(nn.Module):
def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):
super().__init__()
self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,stride=strides,padding=1)
self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(num_channels)
self.bn2 = nn.BatchNorm2d(num_channels)
def forward(self,X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
Y += X
return F.relu(Y)
b1 = nn.Sequential(
nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
nn.BatchNorm2d(64),nn.ReLU(),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
def resnet_block(input_channels,num_channels,num_residuals,first_block=False):
blk = []
for i in range(num_residuals):
if i==0 and not first_block:
blk.append(Residual(input_channels,num_channels,use_1x1conv=True,strides=2))
else:
blk.append(Residual(num_channels,num_channels))
return blk
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(), nn.Linear(512, 10))
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
lr, num_epochs = 0.05, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
运行结果: