Building LeNet in PyTorch to Classify MNIST

A LeNet convolutional neural network

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision import datasets,transforms
import torchvision
# Load the data and define preprocessing
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
train_data = datasets.MNIST(root=r"D:\jupyter_data", transform=transform, train=True)   # add download=True on the first run
test_data = datasets.MNIST(root=r"D:\jupyter_data", transform=transform, train=False)
batch = 256
train_iter = Data.DataLoader(dataset=train_data, batch_size=batch, shuffle=True)
test_iter = Data.DataLoader(dataset=test_data, batch_size=batch, shuffle=True)
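As a quick sanity check (not in the original post, just a minimal sketch using the loaders defined above), one batch can be drawn to confirm the shapes and the value range produced by the transforms:

# Inspect one batch from train_iter (defined above)
X, y = next(iter(train_iter))
print(X.shape)   # torch.Size([256, 1, 28, 28]) -- batch, channel, height, width
print(y.shape)   # torch.Size([256])
print(X.min().item(), X.max().item())   # roughly -1.0 to 1.0 after Normalize(mean=0.5, std=0.5)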

1. Model

# Convolution block: conv -> activation -> pool -> conv -> activation -> pool
conv = nn.Sequential(
    nn.Conv2d(1, 6, 5),    # 1 input channel, 6 output channels, 5x5 kernel
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2),    # 2x2 pooling, stride 2
    nn.Conv2d(6, 16, 5),
    nn.Sigmoid(),
    nn.MaxPool2d(2, 2)
)
#print(conv)
'''
    Note: the conv block outputs a 4-D tensor (batch size, channels, height, width);
    each sample has to be flattened into a 1-D feature vector before it can enter the fully connected block.
'''
# Flattens (batch size, channels, height, width) into (batch size, features)
class transDe(nn.Module):
    def __init__(self):
        super(transDe, self).__init__()
    def forward(self, X):
        return X.view(X.shape[0], -1)
transDe = transDe()
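As a side note (assuming a reasonably recent PyTorch, 1.2 or later), the built-in nn.Flatten module does the same job as the hand-written transDe above:

# Built-in alternative to transDe (transDe is kept above to match the original code)
# nn.Flatten() flattens every dimension except the batch dimension by default
flatten = nn.Flatten()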

# Fully connected block: maps the flattened features to the 10 class scores
fc = nn.Sequential(
    nn.Linear(16*4*4, 120),  # 16 channels of 4x4 features: the 28x28 MNIST input becomes 28-5+1=24 after the first conv, 24/2=12 after pooling, 12-5+1=8 after the second conv, and 8/2=4 after pooling
    nn.Sigmoid(),
    nn.Linear(120, 10)
)
#print(fc)
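The 16*4*4 figure in the comment above can be verified by pushing a dummy MNIST-sized tensor through the conv block (a small check, not part of the original code):

# Verify the 16*4*4 shape arithmetic: feed a dummy 1x1x28x28 tensor through conv
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)
    out = conv(dummy)
print(out.shape)              # torch.Size([1, 16, 4, 4])
print(out.view(1, -1).shape)  # torch.Size([1, 256]) -> matches Linear(16*4*4, 120)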
net = nn.Sequential()
net.add_module('conv',conv)
net.add_module('trans',transDe)
net.add_module('fc',fc)
print(net)
Sequential(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Sigmoid()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Sigmoid()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (trans): transDe()
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): Sigmoid()
    (2): Linear(in_features=120, out_features=10, bias=True)
  )
)

2. Define the loss function and optimization method

loss = nn.CrossEntropyLoss()
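nn.CrossEntropyLoss applies log-softmax internally and expects raw logits plus integer class labels, which is why the network ends with a plain Linear layer and the targets are not one-hot encoded. A minimal illustration with made-up values:

# CrossEntropyLoss takes logits of shape (N, num_classes) and integer labels of shape (N,)
logits = torch.randn(4, 10)          # e.g. 4 samples, 10 classes
labels = torch.tensor([3, 7, 0, 9])  # class indices, not one-hot vectors
print(loss(logits, labels))          # scalar tensor (mean over the batch by default)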

3. Training

  • I didn't want to cook my CPU, so I only trained on this small slice of the data; the improvement in accuracy is still clearly visible.
epoch_size = 5
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
for epoch in range(0, epoch_size+1):  # note: this runs epoch_size+1 = 6 epochs
    train_acc, n, train_loss = 0, 0, 0
    i = 0
    for X, y in train_iter:
        y_pre = net(X)
        l = loss(y_pre, y).sum()  # .sum() is a no-op here: CrossEntropyLoss already returns the batch-mean scalar
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        train_acc += (y_pre.argmax(dim=1) == y).sum().item()
        n += y.shape[0]
        train_loss += l.item()
        if i == 20:  # stop after 21 batches (roughly 256*20 images) per epoch
            break
        i += 1
    print('epoch:%d, loss:%.4f, train_acc:%.4f' % (epoch+1, train_loss/n, train_acc/n))
epoch:1, loss:0.0091, train_acc:0.1324
epoch:2, loss:0.0080, train_acc:0.3242
epoch:3, loss:0.0041, train_acc:0.6819
epoch:4, loss:0.0020, train_acc:0.8432
epoch:5, loss:0.0014, train_acc:0.8904
epoch:6, loss:0.0010, train_acc:0.9198
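The test loader test_iter built earlier is never actually used in the post; a possible evaluation pass over the test set could look like the following sketch (the helper name evaluate is my own):

# Sketch: accuracy on the test set using the test_iter defined above
def evaluate(model, data_iter):
    model.eval()   # no dropout/batchnorm here, but switching modes is good practice
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            correct += (model(X).argmax(dim=1) == y).sum().item()
            total += y.shape[0]
    model.train()
    return correct / total

print('test_acc: %.4f' % evaluate(net, test_iter))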

Lessons learned

Adding one more linear layer to the fully connected block (the original LeNet actually includes this layer) does give somewhat better training results:

fc2 = nn.Sequential(
    nn.Linear(16*4*4, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),  # newly added layer
    nn.Sigmoid(),
    nn.Linear(84, 10)
)
#print(fc2)
net2 = nn.Sequential()
net2.add_module('conv', conv)
net2.add_module('trans', transDe)
net2.add_module('fc', fc2)
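One caveat: the conv and transDe added to net2 here are the same module objects that were just trained inside net, so net2 starts from already-trained convolutional weights rather than from scratch. For a genuinely from-scratch comparison one could rebuild the conv block first, e.g.:

# Hypothetical from-scratch variant: rebuild the conv block so it does not inherit
# the weights trained in section 3 (the results below were produced with the shared conv)
fresh_conv = nn.Sequential(
    nn.Conv2d(1, 6, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2)
)
# net2.add_module('conv', fresh_conv)  # would replace the shared block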
epoch_size = 5
optimizer = torch.optim.Adam(net2.parameters(), lr=0.01)
for epoch in range(0, epoch_size+1):  # again 6 epochs in total
    train_acc, n, train_loss = 0, 0, 0
    i = 0
    for X, y in train_iter:
        y_pre = net2(X)
        l = loss(y_pre, y).sum()
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        train_acc += (y_pre.argmax(dim=1) == y).sum().item()
        n += y.shape[0]
        train_loss += l.item()
        if i == 20:  # stop after 21 batches (roughly 256*20 images) per epoch
            break
        i += 1
    print('epoch:%d, loss:%.4f, train_acc:%.4f' % (epoch+1, train_loss/n, train_acc/n))
epoch:1, loss:0.0072, train_acc:0.4317
epoch:2, loss:0.0022, train_acc:0.8551
epoch:3, loss:0.0012, train_acc:0.9109
epoch:4, loss:0.0008, train_acc:0.9433
epoch:5, loss:0.0007, train_acc:0.9451
epoch:6, loss:0.0006, train_acc:0.9548
