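In the previous article, PyTorch-Tutorials [PyTorch official tutorials explained in Chinese and English] - 6 Autograd, we introduced torch.autograd. Next, let's look at how to choose a loss function and an optimizer when training and optimizing a model.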
Original article: Optimizing Model Parameters, PyTorch Tutorials 1.10.1+cu102 documentation
Now that we have a model and data, it's time to train, validate and test our model by optimizing its parameters on our data. Training a model is an iterative process: in each iteration the model makes a guess about the output, calculates the error in its guess (the loss), collects the derivatives of the error with respect to its parameters (as we saw in the previous section), and optimizes these parameters using gradient descent. One full pass over the training data is called an epoch. For a more detailed walkthrough of this process, check out this video on backpropagation from 3Blue1Brown.
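To make the four steps of one iteration concrete, here is a minimal sketch on a hypothetical toy model y = w * x with a single trainable weight `w` (the data, the initial weight and the learning rate of 0.1 are all made up for illustration). It performs exactly one guess, loss, gradient, update cycle by hand, using autograd for the derivative.

```python
import torch

# Hypothetical toy problem: the true relation is y = 2x, we start from w = 0.5
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 4.0, 6.0])
w = torch.tensor(0.5, requires_grad=True)  # the single trainable parameter
lr = 0.1                                   # assumed learning rate

pred = w * x                     # 1. make a guess about the output
loss = ((pred - y) ** 2).mean()  # 2. measure the error (mean squared error)
loss.backward()                  # 3. autograd computes d(loss)/dw into w.grad

with torch.no_grad():
    w -= lr * w.grad             # 4. gradient-descent step on the parameter
    w.grad.zero_()               # clear the gradient before any next iteration

new_loss = ((w * x - y) ** 2).mean()
print(loss.item(), w.item(), new_loss.item())
```

Here the gradient is -14, so the step moves `w` from 0.5 to 1.9 and the loss drops from 10.5 to roughly 0.047. A real training loop repeats this cycle for every batch, and the optimizer object introduced below replaces the manual update line.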
We load the code from the previous sections on Datasets & DataLoaders and Build Model.
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# FashionMNIST training and test splits, downloaded to ./data on first use
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)  # (N, 1, 28, 28) -> (N, 784)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
```
Hyperparameters are adjustable parameters that let you control the model optimization process. Different hyperparameter values can impact model training and convergence rates (read more about hyperparameter tuning).
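We define the following hyperparameters for training:
- Number of Epochs: the number of times to iterate over the dataset.
- Batch Size: the number of data samples propagated through the network before the parameters are updated.
- Learning Rate: how much to update the model's parameters at each batch/epoch. Smaller values yield slow learning speed, while large values may result in unpredictable behavior during training.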
```python
learning_rate = 1e-3
batch_size = 64
epochs = 5
```
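The batch size directly determines how many optimizer steps happen per epoch. The sketch below assumes a stand-in dataset of 60,000 dummy samples (the size of the FashionMNIST training split) so it runs without a download, and checks the arithmetic: 60000 / 64 = 937.5, so a DataLoader that keeps the last partial batch yields 938 batches.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: dummy data with the same sample count as FashionMNIST's training split
num_samples = 60000
batch_size = 64
dummy_data = TensorDataset(torch.zeros(num_samples, 1))
loader = DataLoader(dummy_data, batch_size=batch_size)  # drop_last defaults to False

batches_per_epoch = len(loader)  # number of parameter updates in one epoch
print(batches_per_epoch, math.ceil(num_samples / batch_size))

# train_loop (defined later) prints every 100 batches, with "current" = batch * batch_size
report_points = [batch * batch_size for batch in range(0, batches_per_epoch, 100)]
print(report_points)
```

The last batch holds only 60000 - 937 * 64 = 32 samples. The `report_points` list, 0, 6400, ..., 57600, is the same sequence of sample counts that appears in the training log at the end of this article.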
Once we set our hyperparameters, we can then train and optimize our model with an optimization loop. Each iteration of the optimization loop is called an epoch.
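Each epoch consists of two main parts:
- The Train Loop: iterate over the training dataset and try to converge to optimal parameters.
- The Validation/Test Loop: iterate over the test dataset to check if model performance is improving.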
Let’s briefly familiarize ourselves with some of the concepts used in the training loop. Jump ahead to see the Full Implementation of the optimization loop.
When presented with some training data, our untrained network is likely not to give the correct answer. A loss function measures how dissimilar the obtained result is from the target value, and it is the loss function that we want to minimize during training. To calculate the loss, we make a prediction using the inputs of our given data sample and compare it against the true data label value.
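A small illustration of "degree of dissimilarity": the sketch below compares one made-up target vector with a close prediction and a far prediction, using nn.MSELoss as the dissimilarity measure (the numbers are hypothetical). The further the prediction is from the target, the larger the loss.

```python
import torch
from torch import nn

mse = nn.MSELoss()  # mean of the squared element-wise differences

# Hypothetical target and two candidate predictions
target = torch.tensor([1.0, 0.0, 3.0])
close_pred = torch.tensor([1.1, 0.0, 2.9])
far_pred = torch.tensor([3.0, -2.0, 0.0])

close_loss = mse(close_pred, target)  # (0.01 + 0 + 0.01) / 3
far_loss = mse(far_pred, target)      # (4 + 4 + 9) / 3
print(close_loss.item(), far_loss.item())
```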
Common loss functions include nn.MSELoss (Mean Squared Error) for regression tasks and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
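The claim that nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss can be checked numerically. This sketch uses random logits for a hypothetical batch of 4 samples and 10 classes; both routes should give the same value up to floating-point rounding.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw, unnormalized scores: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])  # hypothetical class indices

# Route 1: cross-entropy applied directly to the logits
ce = nn.CrossEntropyLoss()(logits, labels)

# Route 2: log-softmax over the class dimension, then negative log likelihood
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

print(ce.item(), nll.item())
```

This is why the model returns raw logits: if it ended in a LogSoftmax layer, you would pair it with nn.NLLLoss instead of nn.CrossEntropyLoss.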
We pass our model’s output logits to nn.CrossEntropyLoss, which will normalize the logits and compute the prediction error.
```python
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
```
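A useful sanity check for nn.CrossEntropyLoss with 10 classes: a model that is completely uninformed (identical logits for every class, so each class gets probability 1/10) has a loss of exactly -ln(1/10) = ln 10, about 2.3026. The batch of 64 all-zero logits and random labels below is a hypothetical stand-in for such a model's output.

```python
import math

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Hypothetical "uninformed" output: equal logits for all 10 classes
uniform_logits = torch.zeros(64, 10)
labels = torch.randint(0, 10, (64,))  # any labels give the same loss here

uniform_loss = loss_fn(uniform_logits, labels)
print(uniform_loss.item(), math.log(10))  # both about 2.3026
```

The first loss printed in the training log at the end of this article, about 2.29, is consistent with this: the freshly initialized network's predictions are close to uniform.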
Optimization is the process of adjusting model parameters to reduce model error in each training step. Optimization algorithms define how this process is performed (in this example we use Stochastic Gradient Descent). All optimization logic is encapsulated in the optimizer object. Here, we use the SGD optimizer; PyTorch also provides many other optimizers, such as Adam and RMSprop, that can work better for different kinds of models and data.
We initialize the optimizer by registering the model’s parameters that need to be trained, and passing in the learning rate hyperparameter.
```python
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```
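What does `optimizer.step()` actually do for plain SGD (default settings, no momentum or weight decay)? It applies p = p - lr * p.grad to every registered parameter. The sketch below verifies this on a hypothetical toy `nn.Linear(3, 1)` layer with random inputs, by computing the expected update by hand before calling `step()`.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)  # hypothetical toy model: a weight and a bias parameter
lr = 1e-3
optimizer = torch.optim.SGD(layer.parameters(), lr=lr)

x = torch.randn(8, 3)
y = torch.randn(8, 1)
loss = nn.MSELoss()(layer(x), y)
loss.backward()  # fills .grad for the weight and the bias

# Expected result of one plain SGD step, computed manually
expected = [(p - lr * p.grad).detach().clone() for p in layer.parameters()]

optimizer.step()  # updates the parameters in place
matches = [torch.allclose(p, e) for p, e in zip(layer.parameters(), expected)]
print(matches)
```

Swapping in another optimizer such as torch.optim.Adam only changes this update rule; the surrounding training code stays the same.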
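Inside the training loop, optimization happens in three steps:
- Call optimizer.zero_grad() to reset the gradients of model parameters. Gradients add up by default; to prevent double-counting, we explicitly zero them at each iteration.
- Backpropagate the prediction loss with a call to loss.backward(). PyTorch deposits the gradient of the loss with respect to each parameter.
- Once we have our gradients, call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.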
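Why the gradients must be reset: `backward()` adds into `.grad` rather than overwriting it. This sketch uses a hypothetical single parameter `w` and the function 3w (whose derivative is 3) to show the double-counting and the effect of clearing the gradient, which is what optimizer.zero_grad() does for every parameter it manages.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)  # hypothetical parameter; d(3w)/dw = 3

(3 * w).backward()
first = w.grad.item()        # 3.0

(3 * w).backward()           # no reset: the new gradient is added to the old one
accumulated = w.grad.item()  # 6.0, double-counted

w.grad.zero_()               # manual stand-in for optimizer.zero_grad()
(3 * w).backward()
after_reset = w.grad.item()  # 3.0 again

print(first, accumulated, after_reset)
```

Recent PyTorch versions make zero_grad() set each `.grad` to None instead of filling it with zeros by default; for the next `backward()` call the effect is the same.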
We define train_loop that loops over our optimization code, and test_loop that evaluates the model’s performance against our test data.
```python
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: reset gradients, compute new ones, update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 batches
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches  # average loss per batch
    correct /= size           # fraction of correctly classified samples
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
```
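The accuracy line in test_loop counts a prediction as correct when the class with the largest logit matches the label. The sketch below runs the same expression on hypothetical logits for 4 samples and 3 classes, where the third sample is misclassified.

```python
import torch

# Hypothetical logits (4 samples, 3 classes) and true labels
pred = torch.tensor([[2.0, 0.1, -1.0],   # argmax -> 0
                     [0.2, 0.3, 1.5],    # argmax -> 2
                     [1.0, 3.0, 0.0],    # argmax -> 1 (label is 0: wrong)
                     [0.5, 0.4, 0.1]])   # argmax -> 0
y = torch.tensor([0, 2, 0, 0])

hits = pred.argmax(1) == y                     # element-wise comparison
correct = hits.type(torch.float).sum().item()  # same expression as in test_loop
accuracy = correct / len(y)
print(f"Accuracy: {(100*accuracy):>0.1f}%")
```

This prints `Accuracy: 75.0%`. Note that test_loop divides `correct` by the number of samples but `test_loss` by the number of batches, so "Avg loss" is an average of per-batch mean losses.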
We initialize the loss function and optimizer, and pass them to train_loop and test_loop. Feel free to increase the number of epochs to track the model's improving performance.
```python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
```
Epoch 1
loss: 2.290156 [ 0/60000]
loss: 2.275099 [ 6400/60000]
loss: 2.256799 [12800/60000]
loss: 2.252760 [19200/60000]
loss: 2.235528 [25600/60000]
loss: 2.205756 [32000/60000]
loss: 2.204928 [38400/60000]
loss: 2.172354 [44800/60000]
loss: 2.160271 [51200/60000]
loss: 2.127511 [57600/60000]
Test Error:
Accuracy: 49.9%, Avg loss: 2.116347
Epoch 2
loss: 2.124757 [ 0/60000]
loss: 2.107859 [ 6400/60000]
loss: 2.045332 [12800/60000]
loss: 2.061512 [19200/60000]
loss: 2.002954 [25600/60000]
loss: 1.940844 [32000/60000]
loss: 1.962774 [38400/60000]
loss: 1.874285 [44800/60000]
loss: 1.875532 [51200/60000]
loss: 1.802694 [57600/60000]
Test Error:
Accuracy: 58.7%, Avg loss: 1.794751
Epoch 3
loss: 1.830118 [ 0/60000]
loss: 1.797928 [ 6400/60000]
loss: 1.670504 [12800/60000]
loss: 1.718298 [19200/60000]
loss: 1.605203 [25600/60000]
loss: 1.560042 [32000/60000]
loss: 1.583883 [38400/60000]
loss: 1.483568 [44800/60000]
loss: 1.515428 [51200/60000]
loss: 1.414553 [57600/60000]
Test Error:
Accuracy: 62.0%, Avg loss: 1.430290
Epoch 4
loss: 1.499763 [ 0/60000]
loss: 1.472005 [ 6400/60000]
loss: 1.319050 [12800/60000]
loss: 1.399100 [19200/60000]
loss: 1.283040 [25600/60000]
loss: 1.279892 [32000/60000]
loss: 1.300507 [38400/60000]
loss: 1.221794 [44800/60000]
loss: 1.262865 [51200/60000]
loss: 1.173478 [57600/60000]
Test Error:
Accuracy: 63.9%, Avg loss: 1.193923
Epoch 5
loss: 1.268049 [ 0/60000]
loss: 1.260393 [ 6400/60000]
loss: 1.092561 [12800/60000]
loss: 1.205449 [19200/60000]
loss: 1.083632 [25600/60000]
loss: 1.101792 [32000/60000]
loss: 1.134809 [38400/60000]
loss: 1.062815 [44800/60000]
loss: 1.108174 [51200/60000]
loss: 1.035161 [57600/60000]
Test Error:
Accuracy: 65.1%, Avg loss: 1.049588
Epoch 6
loss: 1.114492 [ 0/60000]
loss: 1.130664 [ 6400/60000]
loss: 0.944653 [12800/60000]
loss: 1.083935 [19200/60000]
loss: 0.961972 [25600/60000]
loss: 0.981254 [32000/60000]
loss: 1.033072 [38400/60000]
loss: 0.961604 [44800/60000]
loss: 1.007507 [51200/60000]
loss: 0.948494 [57600/60000]
Test Error:
Accuracy: 66.0%, Avg loss: 0.956025
Epoch 7
loss: 1.006542 [ 0/60000]
loss: 1.046684 [ 6400/60000]
loss: 0.842564 [12800/60000]
loss: 1.002121 [19200/60000]
loss: 0.884486 [25600/60000]
loss: 0.895794 [32000/60000]
loss: 0.965427 [38400/60000]
loss: 0.895181 [44800/60000]
loss: 0.937755 [51200/60000]
loss: 0.889426 [57600/60000]
Test Error:
Accuracy: 67.3%, Avg loss: 0.891673
Epoch 8
loss: 0.926312 [ 0/60000]
loss: 0.987333 [ 6400/60000]
loss: 0.768049 [12800/60000]
loss: 0.943189 [19200/60000]
loss: 0.831892 [25600/60000]
loss: 0.833098 [32000/60000]
loss: 0.916814 [38400/60000]
loss: 0.850216 [44800/60000]
loss: 0.887719 [51200/60000]
loss: 0.846100 [57600/60000]
Test Error:
Accuracy: 68.5%, Avg loss: 0.844885
Epoch 9
loss: 0.864126 [ 0/60000]
loss: 0.941802 [ 6400/60000]
loss: 0.711602 [12800/60000]
loss: 0.898299 [19200/60000]
loss: 0.793915 [25600/60000]
loss: 0.786041 [32000/60000]
loss: 0.879356 [38400/60000]
loss: 0.818412 [44800/60000]
loss: 0.850554 [51200/60000]
loss: 0.812724 [57600/60000]
Test Error:
Accuracy: 69.7%, Avg loss: 0.809041
Epoch 10
loss: 0.814177 [ 0/60000]
loss: 0.904296 [ 6400/60000]
loss: 0.667563 [12800/60000]
loss: 0.862825 [19200/60000]
loss: 0.764706 [25600/60000]
loss: 0.750034 [32000/60000]
loss: 0.848550 [38400/60000]
loss: 0.794559 [44800/60000]
loss: 0.821466 [51200/60000]
loss: 0.785530 [57600/60000]
Test Error:
Accuracy: 70.9%, Avg loss: 0.780144
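The run above needs the FashionMNIST download. As a self-contained sketch of the same epoch structure, the code below trains a small network on hypothetical synthetic 2-class data (points labeled by which side of the line x0 + x1 = 0 they fall on), with an assumed learning rate of 0.1 and 10 epochs. For brevity it evaluates on the training data itself, whereas the tutorial's test_loop uses a held-out test set.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Assumption: synthetic, linearly separable 2-class data instead of FashionMNIST
torch.manual_seed(0)
X = torch.randn(512, 2)
y = (X[:, 0] + X[:, 1] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_epochs = 10

def evaluate():
    # Same idea as test_loop: no gradients, mean loss, accuracy via argmax
    with torch.no_grad():
        logits = model(X)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss_fn(logits, y).item(), acc

history = [evaluate()]  # entry 0: the untrained model
for epoch in range(num_epochs):
    for xb, yb in loader:  # one epoch = one full pass over the data
        loss = loss_fn(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    history.append(evaluate())

for epoch, (avg_loss, acc) in enumerate(history):
    print(f"after epoch {epoch}: loss {avg_loss:.4f}, accuracy {100*acc:.1f}%")
```

As in the FashionMNIST log, the printed loss should shrink from one epoch to the next; the exact numbers depend on the seed and the PyTorch version.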