Understanding PyTorch's grad mechanism and the concept of computational graphs

import numpy as np
import torch
from torch.autograd import grad
def matmul(x, w):  # despite its name, this is element-wise: y = x^3 * w, not a matrix product
    y = torch.pow(x, 3) * w
    return y


def net_y_xxx(x, w):  # computes the first- through third-order derivatives of y w.r.t. x
    y = matmul(x, w)  # the intermediate y is differentiable even without calling requires_grad_(True) on it
    y_x = grad(y.sum(), x, create_graph=True)[0]
    # create_graph=True is required for higher-order derivatives; without it, y_x would have
    # requires_grad=False, and the graph it was computed from would be freed after this call
    y_xx = grad(y_x.sum(), x, create_graph=True)[0]
    y_xxx = grad(y_xx.sum(), x)[0]  # create_graph defaults to False, so y_xx's graph is freed here

    y.sum().backward()    # still works: y's graph was retained by the first create_graph=True call
    y_x.sum().backward()  # still works: y_x's graph was retained by the second create_graph=True call
    # y_xx.sum().backward()  # would raise a RuntimeError: y_xx's graph was freed while computing y_xxx

    return y, y_x, y_xx, y_xxx


if __name__ == '__main__':
    x = np.array([[1.0],
                  [2.0],
                  [3.0]])

    w = np.array([[1.0],
                  [2.0],
                  [3.0]])
    # only x requires gradients; w is treated as a constant
    x = torch.from_numpy(x).requires_grad_(True)
    # x = torch.from_numpy(x)                       # variant: x as a constant
    # w = torch.from_numpy(w).requires_grad_(True)  # variant: w requiring gradients too
    w = torch.from_numpy(w)

    y, y_x, y_xx, y_xxx = net_y_xxx(x, w)
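
As the commented-out backward call above warns, running backward through a graph that has already been freed raises a RuntimeError. A minimal standalone sketch (my own illustration, not part of the original listing) that reproduces the error:

import torch

x = torch.tensor([2.0], requires_grad=True)
y = (x ** 3).sum()

y.backward()  # first pass works, but frees the graph (retain_graph defaults to False)
try:
    y.backward()  # second pass over the same, already-freed graph
except RuntimeError as e:
    print(e)  # "Trying to backward through the graph a second time ..."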

Computing the third-order derivative of y with respect to x generates three computational graphs in total! Graph 1 is the forward graph of y; graphs 2 and 3 are recorded during the first and second grad calls because create_graph=True; the final grad call leaves create_graph at its default of False, so no fourth graph is built.
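
One way to see the graph count (a hypothetical check, not in the original code) is to inspect grad_fn and requires_grad of the four returned tensors, continuing from the listing above:

y, y_x, y_xx, y_xxx = net_y_xxx(x, w)  # as in the __main__ block
print(y.grad_fn is not None)     # True:  graph 1, the forward graph of y
print(y_x.grad_fn is not None)   # True:  graph 2, recorded by grad(..., create_graph=True)
print(y_xx.grad_fn is not None)  # True:  graph 3, recorded by grad(..., create_graph=True)
print(y_xxx.grad_fn)             # None:  no fourth graph, since create_graph defaulted to False
print(y_xxx.requires_grad)       # False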
