[五]深度学习Pytorch-计算图与动态图机制

0. 往期内容

[一]深度学习Pytorch-张量定义与张量创建

[二]深度学习Pytorch-张量的操作:拼接、切分、索引和变换

[三]深度学习Pytorch-张量数学运算

[四]深度学习Pytorch-线性回归

[五]深度学习Pytorch-计算图与动态图机制

[六]深度学习Pytorch-autograd与逻辑回归

[七]深度学习Pytorch-DataLoader与Dataset(含人民币二分类实战)

[八]深度学习Pytorch-图像预处理transforms

[九]深度学习Pytorch-transforms图像增强(剪裁、翻转、旋转)

[十]深度学习Pytorch-transforms图像操作及自定义方法

深度学习Pytorch-计算图与动态图机制

  • 0. 往期内容
  • 1. 计算图定义
  • 2. 计算图与梯度求导
  • 3. 动态图

1. 计算图定义

[五]深度学习Pytorch-计算图与动态图机制_第1张图片

2. 计算图与梯度求导

[五]深度学习Pytorch-计算图与动态图机制_第2张图片
[五]深度学习Pytorch-计算图与动态图机制_第3张图片
(1)注意叶子节点的梯度会保存,非叶子节点的梯度会释放掉
(2)代码示例
(2-1)非叶子节点的梯度被释放:

import torch

#创建w,x两个节点
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

#创建边运算
a = torch.add(w, x)    

b = torch.add(w, 1)
y = torch.mul(a, b)

# 反向传播,使用自动求导系统反向传播即可得到梯度
y.backward()
print(w.grad)

# 查看叶子结点
#输出True True False False False
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 查看梯度
#输出[5.] [2.] None None None
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

[五]深度学习Pytorch-计算图与动态图机制_第4张图片

(2-2)使用retain_grad()保存非叶子节点的梯度:

import torch

#创建w,x两个节点
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

#创建边运算
a = torch.add(w, x)    

#如果采用a.retain_grad()则a的梯度可以保存
a.retain_grad()

b = torch.add(w, 1)
y = torch.mul(a, b)

# 反向传播,使用自动求导系统反向传播即可得到梯度
y.backward()
print(w.grad)

# 查看叶子结点
#输出True True False False False
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 查看梯度
#使用a.retain_grad()则输出为[5.] [2.] [2.] None None
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

[五]深度学习Pytorch-计算图与动态图机制_第5张图片


grad_fn用来记录创建张量时所使用的方法(函数):
[五]深度学习Pytorch-计算图与动态图机制_第6张图片代码示例

import torch

#创建w,x两个节点
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

#创建边运算
a = torch.add(w, x)    

#如果采用a.retain_grad()则a的梯度可以保存
#a.retain_grad()

b = torch.add(w, 1)
y = torch.mul(a, b)

# 反向传播,使用自动求导系统反向传播即可得到梯度
y.backward()

# 查看 grad_fn
#输出 None None AddBackward0 AddBackward0 MulBackward0
# w与x是None是因为他俩不是运算得到的
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)

在这里插入图片描述

3. 动态图

[五]深度学习Pytorch-计算图与动态图机制_第7张图片[五]深度学习Pytorch-计算图与动态图机制_第8张图片
[五]深度学习Pytorch-计算图与动态图机制_第9张图片

你可能感兴趣的:(深度学习Pyrotch,pytorch,人工智能,深度学习,机器学习,python)