PyTorch深度学习笔记(十六)优化器

课程学习笔记,课程链接

优化器:神经网络的学习的目的就是寻找合适的参数,使得损失函数的值尽可能小。解决这个问题的过程为称为最优化。解决这个问题使用的算法叫做优化器。在 PyTorch 官网中,将优化器放置在 torch.optim 中,并详细介绍了各种优化器的使用方法。

现以 CIFAR10 数据集为例,损失函数选取交叉熵函数,优化器选择 SGD 优化器,搭建神经网络,并计算其损失值,用优化器优化各个参数,使其朝梯度下降的方向调整。设置 epoch,让其执行 20 次,并将每一次完整的训练的损失函数值求和输出。

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
​
dataset = torchvision.datasets.CIFAR10("D:\Code\Project\learn_pytorch\pytorch_p17-21\data", train=False,
                                       download=True, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=4)
​
class Jiaolong(nn.Module):
    def __init__(self):
        super(Jiaolong, self).__init__()
        self.model1 = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
​
    def forward(self, x):
        x = self.model1(x)
        return x
​
loss = nn.CrossEntropyLoss()
jiaolong = Jiaolong()
# 构建 SGD 优化器,其中 jiaolong.parameters() 表示:待优化参数的 iterable 或者是定义了参数组的 dict,lr=0.01 表示学习率
optim = torch.optim.SGD(jiaolong.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = jiaolong(imgs)
        result_loss = loss(outputs, targets)
        # 将上一轮计算的梯度清零,避免上一轮的梯度值会影响下一轮的梯度值计算
        optim.zero_grad()
        # 反向传播过程,在反向传播过程中会计算每个参数的梯度值
        result_loss.backward()
        # 所有的 optimizer 都实现了 step() 方法,该方法会更新所有的参数。
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)

PyTorch深度学习笔记(十六)优化器_第1张图片

你可能感兴趣的:(PyTorch,pytorch,深度学习,神经网络)