在构建好神经网络之后,我们需要对神经网络进行更新,更新的依据就是根据实际数据在网络上的表现,求出我们的期望target和实际上的值eval之间的差距,也就是loss,然后使用优化器更新网络使得loss变小。
1.损失函数
pycharm提供了很多现成的损失函数,但是在遇到实际问题时需要根据我们遇到的问题自己定义损失loss究竟是多少。
这里举一个例子:对于分类问题,pycharm专门提供了的交叉熵损失函数,这个函数就是专门求出分类问题的loss的,我们需要输入图片经过神经网络处理后属于每种类的可能性(eval)和这个图片真实属于的类,交叉熵损失函数会计算出loss,可见,当一张“狗”的图片经过神经网络,如果输出结果中它属于“狗”的可能性最大,并且其他的可能性都很小,那么这个loss就会很低。
2.优化器
优化器的主要作用就是根据损失函数求出的loss,对神经网络的参数进行更新,一般来说,更新都和梯度相关,比如SGD(Stochastic Gradient Descent),即根据梯度下降来更新参数,当然,pytorch也提供了很多的优化器,例如Momentum方法,它使得优化器更新时要考虑上一个时刻的状态,相当于给优化器加了“惯性”,使得它更容易直着走;AdaGrad方法,它对更新的速率进行了强行限制,相当于给算法穿了“鞋子”;大家听的很多的方法Adam方法则是结合了上面描述的两种方法。
3.实际代码实现
此处我们给出一段代码,进行了损失和优化,方便大家深入理解:
import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 加载数据集
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=1)
# 定义神经网络
class Bao(nn.Module):
def __init__(self):
super(Bao, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, kernel_size=(5, 5), padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self, x):
x = self.model1(x)
return x
# 使用tensorboard方便对关心的数进行记录
writer = SummaryWriter("logs")
loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数
bao = Bao()
optim = torch.optim.SGD(bao.parameters(), lr=0.01)
# 定义SGD优化器对神经网络bao的参数进行更新,学习率为0.01,注意,学习率不能太大,也不能太小,太大的话学习不稳定,太小的话学习太慢,一般来说学习率可以单独设置,在开始学习时很大,后面逐渐变小
for epoch in range(10):
epoch_loss = 0
for data in dataloader:
optim.zero_grad() # 将梯度清零,避免上一次循环的梯度影响下一次的计算
imgs, targets = data
outputs = bao(imgs)
result_loss = loss(outputs, targets) # 根据输出和目标计算出损失值
result_loss.backward() # 求出神经网络当中的参数的梯度,以方便之后对神经网络参数进行的更新
optim.step() # 使用优化器进行更新
epoch_loss = epoch_loss + result_loss # 得到每一个epoch的损失值方便之后观察
writer.add_scalar("epoch_loss_0.01", epoch_loss, epoch) # 使用tensorboard记录loss
print(epoch_loss)
writer.close()