5--Residual Networks (ResNet)

The core idea of residual networks is that every additional layer should more easily contain the identity function as one of its elements. The figure below shows the structure of a regular block (left) and a residual block (right); in the residual block, the input can propagate forward faster through the cross-layer data path.

[Figure 1: A regular block (left) and a residual block (right)]
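To make this idea concrete, here is a minimal sketch (an added illustration, not from the original post; TinyResidual is a hypothetical name) showing that a residual mapping f(x) = x + g(x) contains the identity function trivially: driving g's parameters to zero reduces the block to the identity.

import torch
from torch import nn

# Minimal sketch (illustrative only): a residual mapping f(x) = x + g(x).
# `TinyResidual` is a hypothetical helper, not part of ResNet or d2l.
class TinyResidual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.g = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # With g's parameters at zero, the block computes f(x) = x
        return x + self.g(x)

blk = TinyResidual(4)
nn.init.zeros_(blk.g.weight)
nn.init.zeros_(blk.g.bias)
x = torch.randn(1, 4, 8, 8)
print(torch.allclose(blk(x), x))  # True: the block reduces to the identity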

ResNet follows VGG's complete 3×3 convolutional layer design. A residual block first has two 3×3 convolutional layers with the same number of output channels, each followed by a batch normalization layer and a ReLU activation function. A cross-layer data path then skips these two convolution operations and adds the input directly before the final ReLU activation function. This design requires the output of the two convolutional layers to have the same shape as the input, so that the two can be added. If we want to change the number of channels, we need to introduce an extra 1×1 convolutional layer to transform the input into the required shape before the addition. The figure below shows residual block structures with and without a 1×1 convolutional layer:

[Figure 2: Residual blocks with and without a 1×1 convolution]

The code implementation is as follows:

!pip install git+https://github.com/d2l-ai/d2l-zh@release  # installing d2l
!pip install matplotlib_inline
!pip install matplotlib==3.0.0

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

class Residual(nn.Module):
    """A residual block: two 3x3 convolutions plus a skip connection."""
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,
                               stride=strides, padding=1)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,
                               padding=1)
        # Optional 1x1 convolution on the skip path so the input matches the
        # main path's channel count and spatial size before the addition
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X  # add the (possibly transformed) input before the final ReLU
        return F.relu(Y)
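A quick sanity check (an added example, mirroring the checks in the d2l text): with matching channels the block preserves the input's shape, while use_1x1conv=True with strides=2 halves the height and width and changes the channel count.

blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
print(blk(X).shape)  # torch.Size([4, 3, 6, 6]): shape is preserved

blk = Residual(3, 6, use_1x1conv=True, strides=2)
print(blk(X).shape)  # torch.Size([4, 6, 3, 3]): half the resolution, more channels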

# The stem: a 7x7 convolution plus max-pooling quickly reduces the
# resolution (the input is 1-channel Fashion-MNIST)
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    """A stage of residual blocks. Except in the first stage, the first
    block halves the spatial size and changes the channel count."""
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk


# Four stages of two residual blocks each; with the stem and the final
# fully connected layer this gives 18 layers, i.e. ResNet-18
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))


# Global average pooling and a fully connected layer produce the 10-class output
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(512, 10))
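As an added diagnostic (not in the original post), passing a dummy 96×96 input through the network and printing each top-level module's output shape shows how each of stages b3 to b5 halves the resolution while doubling the channel count:

X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:            torch.Size([1, 512])
# Linear output shape:             torch.Size([1, 10])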

batch_size = 256
# Resize the 28x28 Fashion-MNIST images to 96x96 before training
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.05, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

Running results:

[Figure 3: Training results]
