[pytorch学习笔记] 7. 优化模型参数,模型保存和加载

目录

优化模型参数

超参数

优化循环

损失函数

优化器

模型保存和加载

保存和加载模型权重

使用形状保存和加载模型


优化模型参数

现在我们有了模型和数据,是时候通过优化我们的数据上的参数来训练、验证和测试我们的模型了。 训练模型是一个迭代过程; 在每次迭代(称为 epoch)中,模型对输出进行猜测,计算猜测中的误差(损失),收集误差相对于其参数的导数,并优化 这些参数使用梯度下降。

We load the code from the previous sections on Datasets & DataLoaders and Build Model.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

超参数

超参数是可调整的参数,可让您控制模型优化过程。 不同的超参数值会影响模型训练和收敛速度(阅读有关超参数调整的更多信息)

我们为训练定义了以下超参数:

  1. Number of Epochs - 迭代数据集的次数
  2. Batch Size - 参数更新前通过网络传播的数据样本数
  3. Learning Rate - 在每个批次/时期更新模型参数的程度。 较小的值会产生较慢的学习速度,而较大的值可能会导致训练期间出现不可预测的行为。
learning_rate = 1e-3
batch_size = 64
epochs = 5

优化循环

一旦我们设置了超参数,我们就可以使用优化循环来训练和优化我们的模型。 优化循环的每次迭代称为一个时期。

每个时期包括两个主要部分:
训练循环 - 迭代训练数据集并尝试收敛到最佳参数。
验证/测试循环 - 迭代测试数据集以检查模型性能是否正在改善。

损失函数

当呈现一些训练数据时,我们未经训练的网络可能不会给出正确的答案。 损失函数衡量得到的结果与目标值的相异程度,是我们在训练时要最小化的损失函数。 为了计算损失,我们使用给定数据样本的输入进行预测,并将其与真实数据标签值进行比较。

常见的损失函数包括用于回归任务的 nn.MSELoss(均方误差)和用于分类的 nn.NLLLoss(负对数似然)。 nn.CrossEntropyLoss 结合了 nn.LogSoftmax 和 nn.NLLLoss。

我们将模型的输出 logits 传递给 nn.CrossEntropyLoss,它将对 logits 进行归一化并计算预测误差。

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

优化器


优化是在每个训练步骤中调整模型参数以减少模型误差的过程。 优化算法定义了如何执行这个过程(在这个例子中,我们使用随机梯度下降)。 所有优化逻辑都封装在优化器对象中。 在这里,我们使用 SGD 优化器; 此外,PyTorch 中有许多不同的优化器可用,例如 ADAM 和 RMSProp,它们可以更好地用于不同类型的模型和数据。

我们通过注册需要训练的模型参数并传入学习率超参数来初始化优化器。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练循环中,优化分三个步骤进行:

  1. 调用 optimizer.zero_grad() 来重置模型参数的梯度。 默认情况下渐变加起来; 为了防止重复计算,我们在每次迭代时明确地将它们归零。
  2. 通过调用 loss.backward() 反向传播预测损失。 PyTorch 存储损失 w.r.t 的梯度。 每个参数。
  3. 一旦我们有了梯度,我们调用 optimizer.step() 来通过反向传播中收集的梯度来调整参数。

定义了循环优化代码的 train_loop,以及针对我们的测试数据评估模型性能的 test_loop。

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

我们初始化损失函数和优化器,并将其传递给 train_loop 和 test_loop。 随意增加 epoch 的数量来跟踪模型的改进性能。

out:
Epoch 1
-------------------------------
loss: 2.297689  [    0/60000]
loss: 2.290237  [ 6400/60000]
loss: 2.269128  [12800/60000]
loss: 2.264103  [19200/60000]
loss: 2.251986  [25600/60000]
loss: 2.220639  [32000/60000]
loss: 2.226866  [38400/60000]
loss: 2.195838  [44800/60000]
loss: 2.189956  [51200/60000]
loss: 2.159722  [57600/60000]
Test Error:
 Accuracy: 45.1%, Avg loss: 2.150723

Epoch 2
-------------------------------
loss: 2.153910  [    0/60000]
loss: 2.157291  [ 6400/60000]
loss: 2.091198  [12800/60000]
loss: 2.110021  [19200/60000]
loss: 2.076051  [25600/60000]
loss: 2.003284  [32000/60000]
loss: 2.033937  [38400/60000]
loss: 1.952907  [44800/60000]
loss: 1.953560  [51200/60000]
loss: 1.899702  [57600/60000]
Test Error:
 Accuracy: 58.9%, Avg loss: 1.883251

Epoch 3
-------------------------------
loss: 1.904147  [    0/60000]
loss: 1.898742  [ 6400/60000]
loss: 1.761854  [12800/60000]
loss: 1.808874  [19200/60000]
loss: 1.730858  [25600/60000]
loss: 1.654819  [32000/60000]
loss: 1.684215  [38400/60000]
loss: 1.573531  [44800/60000]
loss: 1.599872  [51200/60000]
loss: 1.518306  [57600/60000]
Test Error:
 Accuracy: 61.6%, Avg loss: 1.516425

Epoch 4
-------------------------------
loss: 1.570204  [    0/60000]
loss: 1.558830  [ 6400/60000]
loss: 1.387072  [12800/60000]
loss: 1.471276  [19200/60000]
loss: 1.372889  [25600/60000]
loss: 1.342269  [32000/60000]
loss: 1.366177  [38400/60000]
loss: 1.274934  [44800/60000]
loss: 1.318055  [51200/60000]
loss: 1.236622  [57600/60000]
Test Error:
 Accuracy: 63.6%, Avg loss: 1.250036

Epoch 5
-------------------------------
loss: 1.316210  [    0/60000]
loss: 1.317752  [ 6400/60000]
loss: 1.136089  [12800/60000]
loss: 1.248003  [19200/60000]
loss: 1.140282  [25600/60000]
loss: 1.146560  [32000/60000]
loss: 1.170238  [38400/60000]
loss: 1.094260  [44800/60000]
loss: 1.142456  [51200/60000]
loss: 1.074888  [57600/60000]
Test Error:
 Accuracy: 64.8%, Avg loss: 1.085233

Epoch 6
-------------------------------
loss: 1.147802  [    0/60000]
loss: 1.167118  [ 6400/60000]
loss: 0.971598  [12800/60000]
loss: 1.106905  [19200/60000]
loss: 0.997195  [25600/60000]
loss: 1.015868  [32000/60000]
loss: 1.050142  [38400/60000]
loss: 0.980582  [44800/60000]
loss: 1.028825  [51200/60000]
loss: 0.974603  [57600/60000]
Test Error:
 Accuracy: 65.8%, Avg loss: 0.979109

Epoch 7
-------------------------------
loss: 1.031183  [    0/60000]
loss: 1.070277  [ 6400/60000]
loss: 0.859174  [12800/60000]
loss: 1.013059  [19200/60000]
loss: 0.907712  [25600/60000]
loss: 0.923672  [32000/60000]
loss: 0.971847  [38400/60000]
loss: 0.907493  [44800/60000]
loss: 0.951341  [51200/60000]
loss: 0.907584  [57600/60000]
Test Error:
 Accuracy: 67.1%, Avg loss: 0.907024

Epoch 8
-------------------------------
loss: 0.945793  [    0/60000]
loss: 1.003288  [ 6400/60000]
loss: 0.779033  [12800/60000]
loss: 0.947461  [19200/60000]
loss: 0.847874  [25600/60000]
loss: 0.856052  [32000/60000]
loss: 0.916784  [38400/60000]
loss: 0.859343  [44800/60000]
loss: 0.896248  [51200/60000]
loss: 0.859036  [57600/60000]
Test Error:
 Accuracy: 68.3%, Avg loss: 0.855187

Epoch 9
-------------------------------
loss: 0.880296  [    0/60000]
loss: 0.952778  [ 6400/60000]
loss: 0.719199  [12800/60000]
loss: 0.899208  [19200/60000]
loss: 0.805123  [25600/60000]
loss: 0.804924  [32000/60000]
loss: 0.875258  [38400/60000]
loss: 0.826057  [44800/60000]
loss: 0.855472  [51200/60000]
loss: 0.821711  [57600/60000]
Test Error:
 Accuracy: 69.4%, Avg loss: 0.815939

Epoch 10
-------------------------------
loss: 0.827977  [    0/60000]
loss: 0.911847  [ 6400/60000]
loss: 0.672483  [12800/60000]
loss: 0.862297  [19200/60000]
loss: 0.772755  [25600/60000]
loss: 0.765456  [32000/60000]
loss: 0.841791  [38400/60000]
loss: 0.801990  [44800/60000]
loss: 0.824129  [51200/60000]
loss: 0.791598  [57600/60000]
Test Error:
 Accuracy: 70.9%, Avg loss: 0.784744

Done!

模型保存和加载

import torch
import torchvision.models as models

保存和加载模型权重

PyTorch 模型将学习到的参数存储在称为 state_dict 的内部状态字典中。 这些可以通过 torch.save 方法持久化:

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

要加载模型权重,您需要先创建相同模型的实例,然后使用 load_state_dict() 方法加载参数。

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

一定要在推理之前调用 model.eval()方法,将 dropout 和 batch normalization 层设置为评估模式。 不这样做会产生不一致的推理结果。

使用形状保存和加载模型

加载模型权重时,我们需要先实例化模型类,因为该类定义了网络的结构。 我们可能希望将此类的结构与模型一起保存,在这种情况下,我们可以将模型(而不是 model.state_dict())传递给保存函数:

torch.save(model, 'model.pth')
model = torch.load('model.pth')

参考:

官方文档:Save and Load the Model — PyTorch Tutorials 1.11.0+cu102 documentation

你可能感兴趣的:(pytorch,pytorch)