(Series complete) Learning Deep Learning from Scratch with PyTorch on the MNIST Dataset, Day 3: Training the Model

1. Introduction

This is day three of running the MNIST handwritten-digit dataset with the PyTorch framework; the focus today is training the network. This blog mainly records the learning path and collects the learning materials used along the way.

Note: the code below is written for Python 2.7.

Day 1 (building the LeNet network): https://blog.csdn.net/qq_36627158/article/details/108098147

Day 2 (loading the MNIST dataset): https://blog.csdn.net/qq_36627158/article/details/108119048

Day 3 (training the model): https://blog.csdn.net/qq_36627158/article/details/108163693

Day 4 (testing single samples): https://blog.csdn.net/qq_36627158/article/details/108183655


2. Code (mnist_train.py)

Thanks to 凯神 for providing the code and for the patient guidance!

from lenet import Net
import torch
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mnist_load import testset_loader, trainset_loader


LEARNING_RATE = 0.001
MOMENTUM = 0.9
EPOCH = 5


if torch.cuda.is_available():
    device = torch.device('cuda')
    print 'cuda'
else:
    device = torch.device('cpu')
    print 'cpu'


mnist_model = Net().to(device)

optimizer = optim.SGD(
    mnist_model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM
)


# save_model
def save_checkpoint(checkpoint_path, model, optimizer):
    # state_dict: a Python dictionary object that:
    # - for a model, maps each layer to its parameter tensor;
    # - for an optimizer, contains info about the optimizer's states and hyperparameters used.
    state = {
        'model': model.state_dict(),
        'optimizer' : optimizer.state_dict()
    }
    torch.save(state, checkpoint_path)
    print 'model saved to ', checkpoint_path


# train
def mnist_train(epoch, save_interval):
    mnist_model.train()  # set training mode

    iteration = 0
    loss_plt = []

    for ep in range(epoch):
        for batch_idx, batch_data in enumerate(trainset_loader):
            images, labels = batch_data
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            output = mnist_model(images)

            loss = F.cross_entropy(output, labels)
            loss_plt.append(loss.item())  # store plain floats, not tensors, so plotting works and no graph is kept alive

            loss.backward()
            optimizer.step()

            print 'Train Epoch:', ep+1, '\tBatch: ', batch_idx+1, '/', len(trainset_loader), '\tLoss: ', loss.item()

            # different from before: saving model checkpoints
            if iteration % save_interval == 0 and iteration > 0:
                save_checkpoint('module/pytorch-mnist-batchsize-128-%i.pth' % iteration, mnist_model, optimizer)

            iteration += 1

        mnist_test()

    # save the final model
    save_checkpoint('module/pytorch-mnist-batch-128-%i.pth' % iteration, mnist_model, optimizer)
    plt.plot(loss_plt, label='loss')
    plt.legend()
    plt.show()


# test
def mnist_test():
    mnist_model.eval()  # set evaluation mode

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for images, labels in testset_loader:
            images = images.to(device)
            labels = labels.to(device)

            output = mnist_model(images)

            test_loss += F.cross_entropy(output, labels, reduction='sum').item()  # sum per-sample losses so dividing by the dataset size gives a true average

            pred = output.max(1, keepdim=True)[1]  # get the index of the highest-scoring class
            correct += pred.eq(labels.view_as(pred)).sum().item()


    test_loss /= len(testset_loader.dataset)

    print '\nTest set: Average loss:', test_loss, '\tAccuracy:', (100. * correct / len(testset_loader.dataset)), '%\n'


if __name__ == '__main__':
    mnist_train(EPOCH, save_interval=1000)



3. Materials

1. torch.optim, the optimization-algorithms package

https://pytorch.org/docs/stable/optim.html


4. Details

1. OSError: [Errno 12] Cannot allocate memory

At first I thought my machine was just too low-spec (not enough RAM) to load many images per batch, so I kept lowering the BATCH_SIZE hyperparameter. When the error persisted no matter how small I made BATCH_SIZE, I realized it probably wasn't a memory-capacity problem.

After some searching, it turned out to be the number of worker processes used to load the data batches (see the sketch after the link):

https://blog.csdn.net/breeze210/article/details/99679048
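
A minimal sketch of the workaround, assuming the DataLoader is created in mnist_load.py (the dataset path and transform here are placeholders, not the post's exact code). Lowering num_workers, or setting it to 0 so batches are loaded in the main process, avoids spawning worker processes the machine cannot allocate memory for:

# hypothetical excerpt from mnist_load.py: reduce num_workers if allocation fails
import torch.utils.data as data
from torchvision import datasets, transforms

transform = transforms.ToTensor()
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)

# num_workers=0 loads batches in the main process; raise it only if RAM allows
trainset_loader = data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)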

 

2. You have to create the module folder yourself beforehand

OK, so it turns out that when Python writes a file, it does not automatically create missing folders along the path. Noted!
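
A tiny sketch of the fix, to be run once before the first save_checkpoint call:

# create the checkpoint folder up front (exist_ok is Python 3 only, so check first)
import os

checkpoint_dir = 'module'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)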

 

3. The momentum parameter of the optimizer (more reading on optimizers still needed)

凯神's explanation: momentum is a parameter used by stochastic gradient descent when updating the model weights; roughly, each update blends the current gradient with the direction of previous updates (see the sketch below the links).

https://www.lizenghai.com/archives/29512.html

https://pytorch.org/docs/stable/optim.html
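
A rough sketch of what SGD with momentum does per parameter, simplified from the PyTorch update rule (dampening=0, nesterov=False); the numbers are made up:

# velocity accumulates a running direction; the weight steps along it
lr, momentum = 0.001, 0.9
grad, velocity, weight = 0.5, 0.0, 1.0

velocity = momentum * velocity + grad
weight = weight - lr * velocity
print(weight)   # 0.9995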

 

4. model.to(device)

For a model, .to(device) moves its parameters and buffers onto the specified device; for a data tensor, .to(device) returns a copy on that device. All subsequent computation then runs on that device (see the sketch below the link).

https://www.jb51.net/article/178049.htm
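
A small sketch of the difference between the two uses of .to(device). It assumes the Net class from lenet.py and a 28x28 MNIST-sized input (check lenet.py for the real input size):

import torch
from lenet import Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net().to(device)         # the module's parameters now live on `device`
x = torch.randn(1, 1, 28, 28)    # a fake input batch, created on the CPU
x_dev = x.to(device)             # returns a copy on `device`; x itself stays on the CPU
output = model(x_dev)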

 

5. Module.parameters()

https://blog.csdn.net/qq_39463274/article/details/105295272?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
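
A quick way to see what Module.parameters() yields: an iterator over every learnable tensor of the model, which is exactly what gets handed to the optimizer above. A hedged sketch using the Net class from lenet.py:

from lenet import Net

model = Net()
for name, param in model.named_parameters():   # parameters() is the same iterator without names
    print('%s %s' % (name, tuple(param.size())))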

 

6. if __name__ == "__main__"

https://blog.csdn.net/yjk13703623757/article/details/77918633
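
A minimal illustration of the idiom: the guarded block runs only when the file is executed directly (python mnist_train.py), not when another script imports it:

def mnist_train(epoch, save_interval):
    pass  # training code lives here

if __name__ == '__main__':            # True only when run as a script
    mnist_train(5, save_interval=1000)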

 

7. state_dict()

  • https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict
  • https://blog.csdn.net/VictoriaW/article/details/72821329?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
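
A short look at what the two state_dicts contain (a sketch reusing the names from the training script): the model's dict maps each layer's parameter name to its tensor, and the optimizer's dict stores its state plus the hyperparameters in param_groups:

from lenet import Net
import torch.optim as optim

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for key, tensor in model.state_dict().items():
    print('%s %s' % (key, tuple(tensor.size())))
print(optimizer.state_dict()['param_groups'])   # lr, momentum, etc.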

 

8. checkpoint

A way to save the current state of your experiment so that you can pick up from where you left off (a sketch of the matching load step follows the link below).

https://www.cnblogs.com/Arborday/p/9740253.html
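
A hedged sketch of the reverse of save_checkpoint above: restoring a saved checkpoint so training or testing can pick up from it (the checkpoint path is just an example):

import torch

def load_checkpoint(checkpoint_path, model, optimizer):
    # map_location='cpu' keeps this runnable without a GPU; call model.to(device) afterwards
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    print('model loaded from ' + checkpoint_path)

# load_checkpoint('module/pytorch-mnist-batchsize-128-1000.pth', mnist_model, optimizer)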

 

9. Why optimizer.zero_grad() is needed

Because the gradients computed by the later backward pass accumulate into each parameter's .grad, you zero them first so the previous iteration's gradients don't contaminate the current iteration's gradients.

https://blog.csdn.net/scut_salmon/article/details/82414730
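
A tiny demonstration (not from the original post) that gradients accumulate across backward() calls unless they are zeroed, which is exactly what optimizer.zero_grad() prevents:

import torch

w = torch.tensor(1.0, requires_grad=True)

(w * 2).backward()
print(w.grad)        # tensor(2.)

(w * 2).backward()   # without zeroing, the new gradient is added on top of the old one
print(w.grad)        # tensor(4.)

w.grad.zero_()       # what optimizer.zero_grad() does for every parameter
(w * 2).backward()
print(w.grad)        # tensor(2.)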

 

10. The difference between optimizer.step() and loss.backward()

At first I couldn't quite tell what each of these two functions does. A video that used an analogy made it click:

In linear regression, the weight-update rule is: w_new = w_old - lr * gradient

loss.backward() is the part that computes the gradient.

optimizer.step() is the part that applies w_new = w_old - lr * gradient using that gradient.

https://v.qq.com/x/page/t0554h33liw.html
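
A hedged sketch tying the two calls to the formula above: backward() fills in param.grad, and step() applies the update (done here by hand for plain SGD without momentum; the real optimizer does the same for every parameter):

import torch

lr = 0.001
w = torch.tensor(1.0, requires_grad=True)
x, target = torch.tensor(2.0), torch.tensor(5.0)

loss = (w * x - target) ** 2
loss.backward()                # computes the gradient: w.grad now holds dloss/dw

with torch.no_grad():
    w -= lr * w.grad           # what optimizer.step() would do: w_new = w_old - lr * gradient
print(w)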

 

11. with torch.no_grad() and model.eval()

Use both. They do different things and have different scopes.
with torch.no_grad(): disables gradient tracking in autograd.
model.eval(): changes the forward() behaviour of the module it is called on, e.g. it disables dropout and makes batch norm use its running (population) statistics. (A small dropout check follows the links below.)

https://www.cnblogs.com/shiwanghualuo/p/11789018.html

https://blog.csdn.net/songyu0120/article/details/103884586?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param
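
A small check of what model.eval() changes, using a standalone dropout layer (not the post's model): in training mode the output is random, in eval mode the layer is a no-op. torch.no_grad() is the separate piece that stops autograd from recording operations, as already used in mnist_test() above:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 4)

drop.train()          # training mode: entries randomly zeroed, the rest rescaled by 2
print(drop(x))

drop.eval()           # eval mode: dropout does nothing, output equals the input
print(drop(x))        # tensor([[1., 1., 1., 1.]])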

 
