Pytorch中的torch.autograd.Function的使用

Pytorch中的torch.autograd.Function的使用_第1张图片

 案例:

import torch


class Line(torch.autograd.Function):
    """
    自定义前向和反向传播
    """

    @staticmethod
    def forward(ctx, w, x, b):
        """
        前向传播
        :param ctx: 上下文管理器
        :param w:
        :param x:
        :param b:
        :return:
        """
        ctx.save_for_backward(w, x, b)
        return w * x + b

    @staticmethod
    def backward(ctx, grad_out):
        """
        反向传播
        思路就是上一级梯度乘以当前梯度
        对 w * x + b 链式求导
        :param ctx:上下文管理器
        :param grad_out:上一级梯度
        :return:
        """
        # 获取在前向传播里面的变量信息
        w, x, b = ctx.saved_tensors
        # 对w求导
        # 上一级梯度乘以当前梯度
        grad_w = grad_out * x
        # 对x求导
        grad_x = grad_out * w
        # 对b求导
        grad_b = grad_out
        return grad_w, grad_x, grad_b


w = torch.rand(2,2, requires_grad=True)
x = torch.rand(2,2, requires_grad=True)
b = torch.rand(2,2, requires_grad=True)

# 使用apply调用相应的运算
out = Line.apply(w, x, b)
out.backward(torch.ones(2, 2))

print(w, x, b)
# 查看梯度值
print(w.grad, x.grad, b.grad)

你可能感兴趣的:(Pytorch系列,pytorch)