AlexNet网络和图解

1、Alexnet网络介绍

AlexNet网络和图解_第1张图片
第一,AlexNet第一层中的卷积窗口形状是 11×11 。第二层中的卷积窗口形状减小到 5×5 ,之后全采用 3×3 。此外,第一、第二和第五个卷积层之后都使用了窗口形状为 3×3 、步幅为2的最大池化层。

紧接着最后一个卷积层的是两个输出个数为4096的全连接层。这两个巨大的全连接层带来将近1 GB的模型参数。由于早期显存的限制,最早的AlexNet使用双数据流的设计使一个GPU只需要处理一半模型。幸运的是,显存在过去几年得到了长足的发展,因此通常我们不再需要这样的特别设计了。

第二,AlexNet将sigmoid激活函数改成了更加简单的ReLU激活函数。一方面,ReLU激活函数的计算更简单,例如它并没有sigmoid激活函数中的求幂运算。另一方面,ReLU激活函数在不同的参数初始化方法下使模型更容易训练。这是由于当sigmoid激活函数输出极接近0或1时,这些区域的梯度几乎为0,从而造成反向传播无法继续更新部分模型参数;而ReLU激活函数在正区间的梯度恒为1。因此,若模型参数初始化不当,sigmoid函数可能在正区间得到几乎为0的梯度,从而令模型无法得到有效训练。

第三,AlexNet通过丢弃法来控制全连接层的模型复杂度。

第四,AlexNet引入了大量的图像增广,如翻转、裁剪和颜色变化,从而进一步扩大数据集来缓解过拟合。我们将在后面的“图像增广”一节详细介绍这种方法。

2、代码实现

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary_depth
import os
import sys

class MyAlexNet(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(MyAlexNet, self).__init__(**kwargs)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数
            nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 连续3个卷积层,且使用更小的卷积窗口。除了最后的卷积层外,进一步增大了输出通道数。
            # 前两个卷积层后不使用池化层来减小输入的高和宽
            nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.linear = nn.Sequential(
            # 这里全连接层的输出个数比LeNet中的大数倍。使用丢弃层来缓解过拟合
            nn.Linear(5*5*256, 4096), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
            # 输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000
            nn.Linear(4096, 10)
        )
    
    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(-1, 5*5*256)
        x = self.linear(x)
        return x

你可能感兴趣的:(卷积,网络,深度学习,神经网络,计算机视觉)