二、pytorch核心概念:3.动态计算图

本博客是阅读eat pytorch in 20 day第二章的个人笔记

动态计算图

计算图由节点和边组成,节点是张量和函数,边表示依赖关系。动态的含义是,前向传播时每一步会立即得到计算结果,反向传播后计算图会立即销毁。

function同时包含正向计算和反向传播的逻辑,比如relu函数:

class MyReLU(torch.autograd.Function):
   
    #正向传播逻辑,可以用ctx存储一些值,供反向传播使用。
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    #反向传播逻辑
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
relu = MyReLU.apply

只有叶子节点的梯度才会被存到.grad属性里,其余节点的梯度只在计算中出现,并不保存。
.retain_grad()非叶子节点梯度保存、register_hook非叶子节点梯度显示。

import torch 

#正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()

#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)

计算图在TensorBoard中的可视化

你可能感兴趣的:(pytorch,python,人工智能,pytorch)