环境使用 Kaggle 里免费建立的 Notebook
教程使用李沐老师的 动手学深度学习 网站和 视频讲解
小技巧:当遇到函数看不懂的时候可以按 Shift+Tab
查看函数详解。
全连接层的问题:
卷积层需要较少的参数:
但是卷积层后的第一个全连接层的参数量:
这样会导致:
NiN 提出的思想就是在每个像素的通道上分别使用多层感知机,减少计算量。
NiN 论文地址:https://arxiv.org/abs/1312.4400
最初的 NiN 网络是在 AlexNet 后不久提出的,显然从中得到了一些启示。 NiN 使用窗口形状为 11 × 11 11\times 11 11×11、 5 × 5 5\times 5 5×5 和 3 × 3 3\times 3 3×3 的卷积层,输出通道数量与AlexNet中的相同。 每个NiN块后有一个最大池化层,池化窗口形状为 3 × 3 3\times 3 3×3,步幅为 2 2 2。
NiN 和 AlexNet 之间的一个显著区别是 NiN 完全取消了全连接层。 NiN 使用一个 NiN 块,其输出通道数等于标签类别的数量。最后放一个全局平均汇聚层(global average pooling layer),生成一个对数几率 (logits)。NiN 设计的一个优点是,它显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。
NiN 块:
NiN 架构:
NiN 块使用卷积层加两个 1 × 1 1\times 1 1×1 卷积层,使用全局平均池化来代替 VGG 和 AlexNet 中的全连接层
!pip install -U d2l
import torch
from torch import nn
from d2l import torch as d2l
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())
net = nn.Sequential(
nin_block(1, 96, kernel_size=11, strides=4, padding=0),
nn.MaxPool2d(3, stride=2),
nin_block(96, 256, kernel_size=5, strides=1, padding=2),
nn.MaxPool2d(3, stride=2),
nin_block(256, 384, kernel_size=3, strides=1, padding=1),
nn.MaxPool2d(3, stride=2),
nn.Dropout(0.5),
# 标签类别数是10
nin_block(384, 10, kernel_size=3, strides=1, padding=1),
nn.AdaptiveAvgPool2d((1, 1)),
# 将四维的输出转成二维的输出,其形状为(批量大小,10)
nn.Flatten())
这里面一个小疑问:为什么有一个nn.Dropout(0.5)
。李沐老师课上提到了,但是不清楚什么意思。
各层输出形状:
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
X = layer(X)
print(layer.__class__.__name__,'output shape:\t', X.shape)
训练之前记得在 kaggle 上使用 GPU。
lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
结果比之前 AlexNet 精度小 0.004 0.004 0.004 。这是因为数据集小的缘故,实际上在 ImageNet 上效果是要比 AlexNet 略胜一筹的。
在模型定义中,我去掉 nn.Dropout(0.5)
然后又训练了一下:
效果反而更好了。。。不知道这个 Dropout(0.5)
有啥用