Re-implementing MNIST handwritten digit recognition with PyTorch

MNIST is an excellent dataset, so it is the natural first thing to try when getting into PyTorch.

Downloading the dataset

http://yann.lecun.com/exdb/mnist/

The files can be downloaded from LeCun's website, or fetched with the datasets package in torchvision:

from torchvision.datasets import MNIST

# Download MNIST into ./data (skipped if the files are already there)
train_set = MNIST('./data', train=True, download=True)
test_set = MNIST('./data', train=False, download=True)

Note, however, that without a proxy (VPN) the download may never finish; it is sloth-slow.

Preprocessing

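Preprocessing can be done with transforms.Compose(): first divide by 255 (transforms.ToTensor() does this when it converts the image), then apply transforms.Normalize([0.5], [0.5]) to map the pixel values into the closed interval [-1, 1]. An original value of 0, for example, maps to -1. MNIST images are grayscale, so there is only one channel. Ordinary color images have three RGB channels, and transforms.Normalize([r, g, b], [d, e, f]) takes the per-channel means in the first list and the per-channel standard deviations (not variances) in the second.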
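To make the mapping concrete, here is a minimal sketch in plain Python of the arithmetic described above. The helper name normalize_pixel is hypothetical; it only mirrors the divide-by-255 step followed by the (x - mean) / std step that Normalize applies to a single channel.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Scale an 8-bit pixel value to [0, 1], then standardize it."""
    scaled = value / 255.0          # 0..255 -> 0.0..1.0
    return (scaled - mean) / std    # 0.0..1.0 -> -1.0..1.0 when mean = std = 0.5


# The darkest and brightest pixels land on the two ends of [-1, 1].
for raw in (0, 255):
    print(raw, "->", normalize_pixel(raw))
```

With mean = std = 0.5 the two extremes of the input range land exactly on -1 and 1, which is why this pair of values is a common default for single-channel images.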

You can also write a custom function:

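import numpy as np
import torch

def data_tf(x):
    x = np.array(x, dtype='float32') / 255  # scale pixel values to [0, 1]
    x = (x - 0.5) / 0.5  # normalize to [-1, 1]
    x = x.reshape((-1,))  # flatten the 28x28 image into a 784-element vector
    x = torch.from_numpy(x)
    return x

Data iterators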

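Hyperparameters such as the batch size and the learning rate can be defined at the top of the code. Here the training batch size is 64 and the training data is shuffled.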

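from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
# Recreate the datasets with data_tf applied, so each image arrives as a flat tensor
train_set = MNIST('./data', train=True, transform=data_tf, download=True)
test_set = MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
a, a_label = next(iter(train_data))  # one batch of images and its labels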
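The training loop further down averages its metrics by dividing by len(train_data) and len(test_data), which are numbers of batches, not numbers of images. As a quick sanity check, the sketch below works those counts out by hand, assuming the standard MNIST split of 60000 training and 10000 test images and the DataLoader default of keeping the last partial batch. The function name num_batches is made up for this example.

```python
import math


def num_batches(num_samples, batch_size):
    """Batches per epoch when the last, smaller batch is kept."""
    return math.ceil(num_samples / batch_size)


print(num_batches(60000, 64))    # training set, batch size 64
print(num_batches(10000, 128))   # test set, batch size 128
```

Because the final batch is smaller than the rest, an average of per-batch accuracies is a close approximation of the per-image accuracy rather than exactly equal to it.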

Defining a simple two-layer network

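from torch import nn

net = nn.Sequential(
    nn.Linear(784, 300),  # 28 * 28 = 784 input pixels
    nn.ReLU(),
    nn.Linear(300, 10)  # 10 output classes
)

Next, define the loss function. You can write your own; for linear regression, for example: def get_loss(y_, y): return torch.mean((y_ - y) ** 2)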

Here we use a loss function that ships with nn:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), 1e-2)  # learning rate 0.01
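For intuition about what nn.CrossEntropyLoss computes for a single sample, the following plain-Python sketch applies a log-softmax to the raw network outputs (logits) and takes the negative log-probability of the true class. The function name cross_entropy and the example logits are made up for illustration; the PyTorch loss additionally averages over the batch by default.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


uniform = [0.0] * 10               # no preference among the 10 digits
print(cross_entropy(uniform, 3))   # equals log(10), about 2.3026

confident = [0.0] * 10
confident[3] = 5.0                 # strongly favors the correct class 3
print(cross_entropy(confident, 3)) # much smaller than the uniform case
```

This is also why the network above ends in a plain nn.Linear layer with no softmax: the loss expects raw logits and applies the softmax itself.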

Then training begins:

losses = []
acces = []
eval_losses = []
eval_acces = []

for e in range(20):
    train_loss = 0
    train_acc = 0
    net.train()
    for im, label in train_data:
        # Forward pass
        out = net(im)
        loss = criterion(out, label)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the loss (.item() turns the 0-dim tensor into a Python float)
        train_loss += loss.item()
        # Compute the classification accuracy for this batch
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / im.shape[0]
        train_acc += acc

    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))
    # Evaluate on the test set
    eval_loss = 0
    eval_acc = 0
    net.eval()  # switch the model to evaluation mode
    with torch.no_grad():  # gradients are not needed for evaluation
        for im, label in test_data:
            out = net(im)
            loss = criterion(out, label)
            # Accumulate the loss
            eval_loss += loss.item()
            # Accumulate the accuracy
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / im.shape[0]
            eval_acc += acc

    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
          .format(e, train_loss / len(train_data), train_acc / len(train_data),
                     eval_loss / len(test_data), eval_acc / len(test_data)))

Training accuracy and test accuracy:

[Figures 2 and 3: accuracy plots]

Changing the depth: a four-layer network

net = nn.Sequential(
    nn.Linear(784, 400),  # 28 * 28 = 784 input pixels
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)  # 10 output classes
)
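As a rough size comparison between the two architectures, the trainable parameters of a stack of nn.Linear layers can be counted by hand: each layer contributes in_features * out_features weights plus out_features biases. The helper count_linear_params below is a hypothetical name used only for this sketch.

```python
def count_linear_params(layer_sizes):
    """Total weights and biases for fully connected layers of the given sizes."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + n_out   # weight matrix plus bias vector
    return total


two_layer = count_linear_params([784, 300, 10])
four_layer = count_linear_params([784, 400, 200, 100, 10])
print(two_layer, four_layer)
```

In both networks the first layer dominates the count, because it is the one connected to all 784 input pixels.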

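With the learning rate set to 0.1, the jitter in Eval Acc is still quite noticeable: moving along the descent direction, the steps are too large, so the function value oscillates back and forth around the extremum.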

epoch: 0, Train Loss: 1.408747, Train Acc: 0.590069, Eval Loss: 0.539895, Eval Acc: 0.835542
epoch: 1, Train Loss: 0.430321, Train Acc: 0.872718, Eval Loss: 0.356913, Eval Acc: 0.900613
epoch: 2, Train Loss: 0.337608, Train Acc: 0.901219, Eval Loss: 0.304655, Eval Acc: 0.906843
epoch: 3, Train Loss: 0.292559, Train Acc: 0.913963, Eval Loss: 0.258706, Eval Acc: 0.924743
epoch: 4, Train Loss: 0.256375, Train Acc: 0.924740, Eval Loss: 0.240557, Eval Acc: 0.929786
epoch: 5, Train Loss: 0.226576, Train Acc: 0.933535, Eval Loss: 0.216834, Eval Acc: 0.933445
epoch: 6, Train Loss: 0.201039, Train Acc: 0.941498, Eval Loss: 0.189868, Eval Acc: 0.942939
epoch: 7, Train Loss: 0.179223, Train Acc: 0.947994, Eval Loss: 0.169639, Eval Acc: 0.949960
epoch: 8, Train Loss: 0.162392, Train Acc: 0.952392, Eval Loss: 0.166602, Eval Acc: 0.950752
epoch: 9, Train Loss: 0.146992, Train Acc: 0.956973, Eval Loss: 0.152519, Eval Acc: 0.953026
epoch: 10, Train Loss: 0.133363, Train Acc: 0.961687, Eval Loss: 0.130080, Eval Acc: 0.961531
epoch: 11, Train Loss: 0.121915, Train Acc: 0.964702, Eval Loss: 0.123021, Eval Acc: 0.962124
epoch: 12, Train Loss: 0.112028, Train Acc: 0.967084, Eval Loss: 0.128418, Eval Acc: 0.959454
epoch: 13, Train Loss: 0.103469, Train Acc: 0.970449, Eval Loss: 0.114526, Eval Acc: 0.964300
epoch: 14, Train Loss: 0.095622, Train Acc: 0.972465, Eval Loss: 0.103326, Eval Acc: 0.967168
epoch: 15, Train Loss: 0.089247, Train Acc: 0.974130, Eval Loss: 0.095964, Eval Acc: 0.969640
epoch: 16, Train Loss: 0.083379, Train Acc: 0.975930, Eval Loss: 0.095576, Eval Acc: 0.969838
epoch: 17, Train Loss: 0.077575, Train Acc: 0.977462, Eval Loss: 0.095448, Eval Acc: 0.970332
epoch: 18, Train Loss: 0.072457, Train Acc: 0.979144, Eval Loss: 0.088108, Eval Acc: 0.972211
epoch: 19, Train Loss: 0.067587, Train Acc: 0.980027, Eval Loss: 0.086642, Eval Acc: 0.972508

[Figures 4 and 5]

With the learning rate set to 0.01, the jitter in Eval Acc is much less noticeable.

epoch: 0, Train Loss: 0.736144, Train Acc: 0.820945, Eval Loss: 0.384030, Eval Acc: 0.897251
epoch: 1, Train Loss: 0.363606, Train Acc: 0.896788, Eval Loss: 0.316678, Eval Acc: 0.910107
epoch: 2, Train Loss: 0.317855, Train Acc: 0.908699, Eval Loss: 0.284327, Eval Acc: 0.919205
epoch: 3, Train Loss: 0.290167, Train Acc: 0.915645, Eval Loss: 0.264882, Eval Acc: 0.923853
epoch: 4, Train Loss: 0.267430, Train Acc: 0.923391, Eval Loss: 0.248133, Eval Acc: 0.929094
epoch: 5, Train Loss: 0.248022, Train Acc: 0.928405, Eval Loss: 0.232285, Eval Acc: 0.932951
epoch: 6, Train Loss: 0.230525, Train Acc: 0.934518, Eval Loss: 0.215797, Eval Acc: 0.938192
epoch: 7, Train Loss: 0.214826, Train Acc: 0.939649, Eval Loss: 0.201602, Eval Acc: 0.940961
epoch: 8, Train Loss: 0.201528, Train Acc: 0.942447, Eval Loss: 0.191393, Eval Acc: 0.945016
epoch: 9, Train Loss: 0.189232, Train Acc: 0.946395, Eval Loss: 0.181257, Eval Acc: 0.947191
epoch: 10, Train Loss: 0.177858, Train Acc: 0.949743, Eval Loss: 0.175677, Eval Acc: 0.949367
epoch: 11, Train Loss: 0.168102, Train Acc: 0.951976, Eval Loss: 0.163249, Eval Acc: 0.951839
epoch: 12, Train Loss: 0.158800, Train Acc: 0.955390, Eval Loss: 0.154844, Eval Acc: 0.955795
epoch: 13, Train Loss: 0.150804, Train Acc: 0.957623, Eval Loss: 0.153768, Eval Acc: 0.954114
epoch: 14, Train Loss: 0.142895, Train Acc: 0.959905, Eval Loss: 0.146686, Eval Acc: 0.957872
epoch: 15, Train Loss: 0.136140, Train Acc: 0.962203, Eval Loss: 0.140544, Eval Acc: 0.959751
epoch: 16, Train Loss: 0.130175, Train Acc: 0.963669, Eval Loss: 0.132720, Eval Acc: 0.961828
epoch: 17, Train Loss: 0.124133, Train Acc: 0.965685, Eval Loss: 0.128871, Eval Acc: 0.962619
epoch: 18, Train Loss: 0.118933, Train Acc: 0.967267, Eval Loss: 0.122724, Eval Acc: 0.964794
epoch: 19, Train Loss: 0.113958, Train Acc: 0.968600, Eval Loss: 0.118547, Eval Acc: 0.965487

[Figures 6 and 7]
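The effect of the step size can be reproduced on a toy problem. The sketch below runs plain gradient descent on f(x) = x**2, which is an assumed stand-in and not the MNIST loss, with two step sizes. The function gradient_descent and the specific learning rates are chosen purely for illustration.

```python
def gradient_descent(lr, x0=1.0, steps=5):
    """Minimize f(x) = x**2 with a fixed step size and return the visited points."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        grad = 2 * x          # derivative of x**2
        x = x - lr * grad
        xs.append(x)
    return xs


small = gradient_descent(lr=0.1)   # each step multiplies x by 0.8
large = gradient_descent(lr=0.9)   # each step multiplies x by -0.8, so the sign flips
print(small)
print(large)
```

With the small step the iterates approach the minimum at 0 from one side. With the large step they jump across it on every update, which is the back-and-forth movement around the extremum described above.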
