Pytorch网络-3

定义网络

  1. 建一个class
  2. 定义网络的层
  3. 定义网络的前向传播,反向传播会自动推导。
    Pytorch网络-3_第1张图片

Dataloader

Pytorch网络-3_第2张图片
2.
Pytorch网络-3_第3张图片Pytorch网络-3_第4张图片

优化器&损失函数

Pytorch网络-3_第5张图片

训练过程

model.train() 指明这个网络有梯度,要更新参数。
optimizer.zero_grad() 优化器清零。
output = model(data) 计算前传得到预测的输出。
loss=F.nall_loss(output, target) 计算损失。
loss.backward() 反向传播计算梯度,梯度默认存在优化器里面。
optimizer.step() 利用优化器对网络参数进行更新。

Pytorch网络-3_第6张图片Pytorch网络-3_第7张图片model.eval() 代表网络没梯度不更新网络参数。
Pytorch网络-3_第8张图片

你可能感兴趣的:(深度学习)