pytorch学习笔记12-完整的模型训练套路

目录

  • 模型训练步骤
    • 各步骤如下
    • argmax用法
  • 完整模型训练
    • 代码
    • 输出结果(部分)
    • 小细节

模型训练步骤

各步骤如下

1.准备数据集
2.查看数据集大小
3.用dataloader加载数据集
4.创建网络模型(一般存为一个model文件,在其中搭建网络模型)
5.创建损失函数
6.创建优化器
7.设置训练网络参数(训练次数、测试次数、训练轮数)
8.开始训练
9.用tensorboard查看loss曲线
10.保存每训练100步的模型结果
11.引入准确率

argmax用法

argmax()返回最大值所在标签。其中的参数keepdim设为1时指水平比较的最大值,为0指竖直比较。
假设有一个二分类问题:
pytorch学习笔记12-完整的模型训练套路_第1张图片

import torch

outputs = torch.tensor([[0.1, 0.2],
                        [0.05, 0.4]])

print(outputs.argmax(1))

当outputs.argmax()中参数设为1时,是指水平看最里面的维度,返回最大值(0.2和0.4)所在标签,均为1。
结果:

tensor([1, 1])

argmax的练习总代码

import torch

n = 1
outputs = torch.tensor([[0.1, 0.2],
                        [0.05, 0.4]])

print(outputs.argmax(1))
print(outputs.argmax(0))

target = torch.tensor([0, 1])
predict = outputs.argmax(1)
print(predict == target)
print((predict == target).sum())

accuracy = (predict == target).sum() / n
print(accuracy)

结果:

tensor([1, 1])
tensor([0, 1])
tensor([False, True])
tensor(1)
tensor(1.)

完整模型训练

代码

from torch.utils.tensorboard import SummaryWriter

from model import *
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="P20_dataset", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="P20_dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
writer = SummaryWriter("P27_train")
# 查看数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)

print("训练集的长度为:{}".format(train_data_size))
print("测试集的长度为:{}".format(test_data_size))

# 用dataloader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=64, drop_last=True)

# 创建网络模型
model = Model()

# 创建损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 设置训练网络的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 设置训练的轮数
epoch = 10

for i in range(epoch):
    print("--------------第 {} 轮训练开始---------------".format(i+1))

    # 训练步骤
    model.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        # 使用item()函数取出的元素值的精度更高,所以在求损失函数时一般用item()
        if total_train_step % 100 == 0:
            print("训练次数: {}, Loss: {} ".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 测试步骤开始
    model.eval()
    # 以测试集上的损失或者正确率来判断模型是否训练的好
    # 验证集与测试集不一样的,验证集是在训练中用的,反正模型过拟合,测试集是在模型完全训练好后使用的
    # 验证集用来调整超参数,相当于真题,测试集是考试
    total_test_loss = 0
    total_accuracy = 0
    # 只需要进行测试,不需要对梯度进行调整,所以设置下面这行
    with torch.no_grad():
        for data in test_dataloader:
            imgs, tartgets = data
            outputs = model(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()     # 横向设为1
            total_accuracy += accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的准确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)

    total_test_step += 1

    torch.save(model, "model_{}.pth".format(i))
    print("模型已保存")

writer.close()

输出结果(部分)

pytorch学习笔记12-完整的模型训练套路_第2张图片
pytorch学习笔记12-完整的模型训练套路_第3张图片

pytorch学习笔记12-完整的模型训练套路_第4张图片

小细节

代码当中在开始训练和开始测试部分都有一行代码

model.train()	# 设置模型为train模式
model.eval()	# 设置模型为eval模式

这两行代码并不是必须加上,如果说搭建的网络模型当中有用到一些特殊的层才有作用。如果没有这些特殊的层,添加这两行代码没有作用。
由于搭建的神经网络都是继承torch.nn.Module(),所以在官方文档可以相关的描述:
pytorch学习笔记12-完整的模型训练套路_第5张图片
(设置train和eval是因为bn层和dropout层在测试和训练时候是不一样的)

你可能感兴趣的:(pytorch,学习,深度学习)