pytorch实现梯度反转层(Gradient Reversal Layer)

问题

在有些任务中,我们需要实现梯度反转层(Gradient Reversal Layer),目的是为了在梯度反向传播时,经过计算图某个节点之后梯度往反向更新(DANN网络中便需要GRL)。pytorch提供了Function用于实现这个方法,但是看网上的博客并没有详细的实现方法的用法。

实现方式

pytorch中的Function

pytorch自定义layer有两种方式:

  • 通过继承torch.nn.Module类来实现拓展。只需重新实现__init__forward函数。
  • 通过继承torch.autograd.Function,除了要实现__init__forward函数,还要实现backward函数(就是自定义求导规则)。
    方式一看着简单,但是当要定义自己的求导方式时,就要自己实现backward,也就是所谓的Extending torch.autograd

关于Function的学习可以参看这个博客:https://blog.csdn.net/qq_27825451/article/details/95189376

因为可以自定义求导的方式,所以我们使用Function实现GRL

实现代码

定义一些无关的类便于测试使用

from typing import Any, Optional, Tuple
from torch.autograd import Function
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy

random.seed(0)
torch.manual_seed(0)
numpy.random.seed(0)

第一种实现方式

  1. 定义一个继承自FunctionGradientReverseFunction
class GradientReverseFunction(Function):
    """
    重写自定义的梯度计算方式
    """
    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None
  1. 在需要反转的代码中使用GRF
class NormalClassifier(nn.Module):

    def __init__(self, num_features, num_classes, GRL=None):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        if GRL:
            self.grl = GRL()

    def forward(self, x):
        if getattr(self, 'grl', None) is not None:
            x = GradientReverseFunction.apply(x)                # 注意这里
        return self.linear(x)

第二种实现方式

如果感觉刚才使用apply的应用方式不习惯,可以包装成一个层

  1. 把第一种方式中的GradientReverseFunction包装成GradientReverseLayer
class GRL(nn.Module):
    def __init__(self):
        super(GRL, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)
  1. 在需要反转的代码中使用GRF
class NormalClassifier(nn.Module):

    def __init__(self, num_features, num_classes, GRL=None):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        if GRL:
            self.grl = GRL()

    def forward(self, x):
        if getattr(self, 'grl', None) is not None:
            x = self.grl(x)                # 注意这里
        return self.linear(x)

看下结果

测试代码

if __name__ == '__main__':
    net1 = NormalClassifier(3, 6)
    net2 = NormalClassifier(6, 10, GRL=None)            # 不使用反转层
    # net2 = NormalClassifier(6, 10, GRL=GRL)           # 使用反转层
    net3 = NormalClassifier(10, 2)

    data = torch.rand((4, 3))
    label = torch.ones((4), dtype=torch.long)
    out = net3(net2(net1(data)))
    loss = F.cross_entropy(out, label)
    loss.backward()

    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net1.linear.weight.grad)
    print('net3.linear.weight.grad', net1.linear.weight.grad)

结果

# 1.这是没有使用GRL
net1.linear.weight.grad tensor([[-0.0027, -0.0044, -0.0026],
        [-0.0420, -0.0675, -0.0400],
        [-0.0030, -0.0048, -0.0029],
        [ 0.0035,  0.0056,  0.0033],
        [-0.0336, -0.0540, -0.0320],
        [-0.0454, -0.0729, -0.0432]])
net2.linear.weight.grad tensor([[ 0.0027,  0.0034, -0.0032, -0.0028,  0.0044, -0.0049],
        [ 0.0452,  0.0577, -0.0544, -0.0473,  0.0747, -0.0830],
        [-0.0897, -0.1146,  0.1081,  0.0939, -0.1483,  0.1647],
        [-0.0702, -0.0897,  0.0846,  0.0735, -0.1161,  0.1290],
        [ 0.0519,  0.0663, -0.0626, -0.0543,  0.0859, -0.0954],
        [ 0.0520,  0.0664, -0.0627, -0.0544,  0.0860, -0.0955],
        [-0.0967, -0.1235,  0.1166,  0.1012, -0.1599,  0.1776],
        [-0.0058, -0.0074,  0.0069,  0.0060, -0.0095,  0.0106],
        [-0.0124, -0.0158,  0.0149,  0.0129, -0.0204,  0.0227],
        [ 0.0830,  0.1060, -0.1000, -0.0869,  0.1373, -0.1525]])
net3.linear.weight.grad tensor([[ 0.1127, -0.2764, -0.0864, -0.1450,  0.2694, -0.1738, -0.1415,  0.3108,
          0.0458, -0.1464],
        [-0.1127,  0.2764,  0.0864,  0.1450, -0.2694,  0.1738,  0.1415, -0.3108,
         -0.0458,  0.1464]])

# 2.这是使用了GRL
net1.linear.weight.grad tensor([[ 0.0027,  0.0044,  0.0026],
        [ 0.0420,  0.0675,  0.0400],
        [ 0.0030,  0.0048,  0.0029],
        [-0.0035, -0.0056, -0.0033],
        [ 0.0336,  0.0540,  0.0320],
        [ 0.0454,  0.0729,  0.0432]])
net2.linear.weight.grad tensor([[ 0.0027,  0.0034, -0.0032, -0.0028,  0.0044, -0.0049],
        [ 0.0452,  0.0577, -0.0544, -0.0473,  0.0747, -0.0830],
        [-0.0897, -0.1146,  0.1081,  0.0939, -0.1483,  0.1647],
        [-0.0702, -0.0897,  0.0846,  0.0735, -0.1161,  0.1290],
        [ 0.0519,  0.0663, -0.0626, -0.0543,  0.0859, -0.0954],
        [ 0.0520,  0.0664, -0.0627, -0.0544,  0.0860, -0.0955],
        [-0.0967, -0.1235,  0.1166,  0.1012, -0.1599,  0.1776],
        [-0.0058, -0.0074,  0.0069,  0.0060, -0.0095,  0.0106],
        [-0.0124, -0.0158,  0.0149,  0.0129, -0.0204,  0.0227],
        [ 0.0830,  0.1060, -0.1000, -0.0869,  0.1373, -0.1525]])
net3.linear.weight.grad tensor([[ 0.1127, -0.2764, -0.0864, -0.1450,  0.2694, -0.1738, -0.1415,  0.3108,
          0.0458, -0.1464],
        [-0.1127,  0.2764,  0.0864,  0.1450, -0.2694,  0.1738,  0.1415, -0.3108,
         -0.0458,  0.1464]])

分析

上面网络结构正向数据流向为:
net1 --> GRL --> net2–> net3

上面网络结构反向数据流向为:
net3 --> net2 --> GRL–> net1

通过输出结果可以看出来,net1的梯度反转了

你可能感兴趣的:(炼丹,pytorch,GRL)