卷积神经网络 -- PyTorch 实现系列之 LeNet(datasets: CIFAR-10)

参考文献:Gradient-Based Learning Applied to Document Recognition

摘要:

        本文利用PyTorch实现了经典神经网络LeNet。

引入:

        目前有很多博客都系统介绍了LeNet的结构以及在各个框架下的代码实现,然而本人发现其中很大一部分博文是存在比较多的问题的,经过仔细阅读LeCun的文章,我把原始LeNet的整个思路在此分享给各位,在下也是刚入门,如有疏漏或错误,请各位读者指正!

        PS:由于用的是torch.nn.Modules实现,有些部分被我简化了,带我深入学习Pytorch后,再研究一下怎么把这些忽略掉的参数再加进去吧……

数据集:

        原文使用的是MNIST,这里我是用了CIFAR-10,最大的不同点只是图片深度从1变成3(黑白和彩色的区别)。

网络结构:

        LeNet是一个由2个卷积层,2个池化层和3个全连接层组成的深度神经网络结构。

     关于池化层:文中描述为采样层,可以理解为一个带参数的AvgPooling,但是肯定不是MaxPooling。这里我用torch.nn.AvgPool2d代替,忽略掉了6+6个参数。

      关于C3:实际上是不是所有的单元都和S2的所有输出相连,这里也是简化了……

      最后一层:后面接的是一个带参数的tanh,这里也把参数忽略了。

卷积神经网络 -- PyTorch 实现系列之 LeNet(datasets: CIFAR-10)_第1张图片

主要代码如下:

print_every = 100
learning_rate = 1e-2
input_depth = 3
'''
layer_0 : [64, 3, 32, 32]   (Input)
layer_1 : [64, 6, 28, 28]   (Conv)
layer_2 : [64, 6, 14, 14]   (AvgPool)
layer_3 : [64, 16, 10, 10]  (Conv)
layer_4 : [64, 16, 5, 5]    (AvgPool)
layer_5 : [16*5*5, 120]     (Linear)
layer_5 : [120, 84]         (Linear)
layer_7 : [84, 10]          (Linear)
'''
layer_1_depth, layer_1_pad, layer_1_kernel = 6, 0,(5, 5) 
layer_2_kernel = (2,2)
layer_3_depth, layer_3_pad, layer_3_kernel= 16, 0,(5, 5) 
layer_4_kernel = (2,2)
layer_5_input, layer_5_output = 400, 120
layer_6_input, layer_6_output = 120, 84
layer_7_input, layer_7_output = 84, 10
model = nn.Sequential(
    nn.Conv2d(input_depth, layer_1_depth, layer_1_kernel, padding=layer_1_pad),
    nn.AvgPool2d(layer_2_kernel),
    nn.Sigmoid(),
    nn.Conv2d(layer_1_depth, layer_3_depth, layer_3_kernel, padding=layer_3_pad),
    nn.AvgPool2d(layer_4_kernel),
    Flatten(),
    nn.Linear(layer_5_input,layer_5_output),
    nn.Linear(layer_6_input,layer_6_output),
    nn.Linear(layer_7_input,layer_7_output),
    )

# you can use Nesterov momentum in optim.SGD
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                     momentum=0.9, nesterov=True)

iterations,losss,val_accs = train(model, optimizer, epochs = 10)

 

 

 

 

你可能感兴趣的:(经典神经网络PyTorch实现)