PyTorch requires_grad/detach

Background

While debugging a distortion issue during GAN training, I went back over how gradients work in PyTorch; this post records that review for my own future reference.

References

  • Several ways to stop gradient flow in pytorch and avoid updating the parameters of modules you don't need

  • A deep dive into tensor, autograd, backpropagation, and the computation graph in pytorch

  • detach() when training GANs in pytorch

Test code

import torch
import torch.nn as nn


def test_requires_grad(requires_grad=False):
    # Toggle requires_grad on lin2's parameters and watch which layers
    # still receive gradients after backward().
    torch.manual_seed(0)
    x = torch.randn(2, 2)
    print('============ input ======== \n {} \n ========================='.format(x))
    # x.requires_grad = True  # left disabled: x stays a plain non-grad input

    lin0 = nn.Linear(2, 2)
    lin1 = nn.Linear(2, 2)
    lin2 = nn.Linear(2, 2)
    lin3 = nn.Linear(2, 2)
    x1 = lin0(x)
    x2 = lin1(x1)
    # lin2's weight and bias are leaf tensors; flipping requires_grad on
    # them controls whether autograd accumulates gradients for lin2 itself.
    for p in lin2.parameters():
        print('is_leaf: {}'.format(p.is_leaf))
        p.requires_grad = requires_grad
    x3 = lin2(x2)
    x4 = lin3(x3)
    x4.sum().backward()
    # With requires_grad=False, lin2.weight.grad stays None, but lin0 and
    # lin1 still get gradients: the backward pass flows through lin2's op
    # because its input x2 requires grad.
    print(lin0.weight.grad)
    print(lin1.weight.grad)
    print(lin2.weight.grad)
    print(lin3.weight.grad)

    # x is a user-created leaf, so x.grad_fn is None; each intermediate
    # output records the op that produced it.
    print(x.grad_fn)
    print(x1.grad_fn)
    print(x2.grad_fn)
    print(x3.grad_fn)
    print(x4.grad_fn)


def test_detach(detach=False):
    # Cut the graph at x2 with detach() and watch everything upstream of
    # the cut (x, lin0, lin1) stop receiving gradients.
    torch.manual_seed(0)
    x = torch.randn(2, 2)
    print('============ input ======== \n {} \n ========================='.format(x))
    x.requires_grad = True
    print(x.is_leaf)
    lin0 = nn.Linear(2, 2)
    lin1 = nn.Linear(2, 2)
    lin2 = nn.Linear(2, 2)
    lin3 = nn.Linear(2, 2)
    x1 = lin0(x)
    x2 = lin1(x1)
    if detach:
        # detach() returns a tensor that shares x2's data but is severed
        # from the computation graph, so backward() stops here.
        x3 = lin2(x2.detach())
    else:
        x3 = lin2(x2)
    x4 = lin3(x3)
    x4.sum().backward()
    # With detach=True, lin0/lin1 gradients are None; lin2/lin3 still
    # receive gradients.
    print(lin0.weight.grad)
    print(lin1.weight.grad)
    print(lin2.weight.grad)
    print(lin3.weight.is_leaf, lin3.weight.grad)

    print(x.grad_fn)
    print(x1.grad_fn)
    print(x2.grad_fn)
    print(x3.grad_fn)
    print(x4.grad_fn)


if __name__ == '__main__':

    # test_detach(True)
    # test_detach(False)

    test_requires_grad(True)
    test_requires_grad(False)
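
Expected output

For test_requires_grad(False), lin2.weight.grad prints as None while lin0, lin1, and lin3 all receive gradients: freezing lin2's parameters only stops gradient accumulation for those parameters; it does not block the backward pass from flowing through lin2's operation to the earlier layers, because x2 itself still requires grad. With test_requires_grad(True), all four layers get gradients. In both cases x.grad_fn is None (x is a user-created leaf), while x1 through x4 report the op that created them (AddmmBackward0 on recent PyTorch versions).

For test_detach(True), the cut is different: x2.detach() severs the graph at that point, so lin0.weight.grad and lin1.weight.grad are None while lin2 and lin3 still receive gradients. Note that x3 and x4 still have a grad_fn even then, because lin2's and lin3's own parameters require grad. With test_detach(False), every layer gets a gradient.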

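Other ways to stop gradient flow

For context, the usual idioms for stopping gradient flow in PyTorch (the topic of the first reference article) come down to three. A minimal sketch; the module name net and the shapes are placeholders of mine:

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
x = torch.randn(2, 2, requires_grad=True)

# 1) Freeze parameters: the module still takes part in forward/backward,
#    but its own weights accumulate no gradient (as in test_requires_grad).
for p in net.parameters():
    p.requires_grad_(False)

# 2) torch.no_grad(): nothing inside the block is recorded in the
#    computation graph; the usual choice for inference.
with torch.no_grad():
    y = net(x)  # y.requires_grad is False

# 3) detach(): cut the graph at one specific tensor, so backward()
#    never reaches anything upstream of it (as in test_detach).
z = net(x.detach())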

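detach() in GAN training

Since the original bug hunt was about GAN training, this is where detach() usually matters: when updating the discriminator, gradients must not flow back into the generator. A minimal sketch of that pattern; netG, netD, and all shapes here are illustrative stand-ins, not the original model:

import torch
import torch.nn as nn

netG = nn.Linear(8, 8)                                # stand-in generator
netD = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # stand-in discriminator
criterion = nn.BCELoss()
optD = torch.optim.SGD(netD.parameters(), lr=0.01)
optG = torch.optim.SGD(netG.parameters(), lr=0.01)

real = torch.randn(4, 8)
fake = netG(torch.randn(4, 8))

# Discriminator step: detach fake so backward() stops at the generator's
# output and netG's parameters receive no gradient from this loss.
optD.zero_grad()
loss_d = (criterion(netD(real), torch.ones(4, 1)) +
          criterion(netD(fake.detach()), torch.zeros(4, 1)))
loss_d.backward()
optD.step()

# Generator step: no detach here, gradients must flow through netD
# back into netG.
optG.zero_grad()
loss_g = criterion(netD(fake), torch.ones(4, 1))
loss_g.backward()
optG.step()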
