LeNet 神经网络

文章目录

  • 1. LeNet 简介
  • 2. LeNet 的 PyTorch 实现

1. LeNet 简介

  LeNet 原是 LeNet1 - LeNet5 这系列网络的合称,但现在所说的 LeNet 则一般特指 LeNet5 (后文统一称为 LeNet)。LeNet 是 Yann LeCun 教授于 1998 年在论文《Gradient-Based Learning Applied to Document Recognition》中提出的 ,设计之初只是用于手写数字的识别,到如今已成为卷积神经网络的 HelloWorld。受限于计算机的算力不足,加之支持向量机 (核学习方法) 的兴起,CNN 方法并未成为当时学术界认可的主流方法。

  算上输入层的话,LeNet 共有 8 层,包含 3 个卷积层,2 个池化层 (下采样层) 和 1 个全连接层。其中,所有卷积操作的核都固定为 5x5,步长为 1;统一使用全局平均池化。LeNet 的网络结构如下

LeNet 神经网络_第1张图片
LeNet 神经网络_第2张图片
  • 输入层:输入图像的尺寸为 32X32;
  • C1 层 (卷积层):使用 6 个核大小为 5×5 的卷积,得到 6 张 28×28 的特征图;
  • S2 层 (池化层,即下采样层):使用 6 个 2×2 的平均池化,得到 6 张 14×14 的特征图;
  • C3 层 (卷积层):使用 16 个核大小为 5×5 的卷积,得到 16 张 10×10 的特征图;
  • S4 层 (池化层):使用 16 个 2×2 的平均池化,得到 16 张 5×5 的特征图;
  • C5 层 (卷积层):使用 120 个核大小为 5×5 的卷积,得到 120 张 1×1 的特征图 (一个向量);
  • F6 层 (全连接层):含 84 个节点的全连接层,对应于一个 7x12 的比特图;
  • 输出层:含 10 个节点的全连接层,分别代表数字 0 到 9。

2. LeNet 的 PyTorch 实现

# _*_coding:utf-8_*_
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(LeNet5, self).__init__()

        # 卷积神经网络
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.MaxPool2d(kernel_size=2)  # 原模型使用的是平均池化
        )
        # 分类器
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120 * 1 * 1),  # 这里将第三个卷积层看成是全连接层
            nn.Linear(120, 84),
            nn.Linear(84, out_channels)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # 铺平tensor
        x = self.classifier(x)
        out = F.softmax(x, 1)  # 激活函数

        return out


if __name__ == "__main__":
    batch_size = 1
    in_channels = 1
    out_channels = 10
    inputs = torch.rand((batch_size, in_channels, 32, 32))  # (B, C, H, W)
    lenet = LeNet5(in_channels, out_channels)
    outputs = lenet(inputs)
    print(outputs)
    print(outputs.sum())
    print(outputs.max())
    print(torch.argmax(outputs, 1))

【参考】

  1. 卷积神经网络之Lenet;
  2. CNN发展简史——LeNet(一);
  3. pytorch LeNet 模型;

你可能感兴趣的:(神经网络浅析,LeNet)