Some tasks call for a Gradient Reversal Layer (GRL): during backpropagation, the gradient is negated as it passes through a chosen node of the computation graph, so all parameters upstream of that node are updated in the opposite direction (DANN, for example, relies on a GRL). PyTorch provides torch.autograd.Function for implementing this, but most blog posts online do not explain the implementation in detail.
There are two ways to define a custom layer in PyTorch:

1. Subclass torch.nn.Module and implement only __init__ and forward; autograd derives the backward pass automatically.
2. Subclass torch.autograd.Function and implement both forward and backward as static methods, i.e. define your own differentiation rule. For an introduction to Function, see this post: https://blog.csdn.net/qq_27825451/article/details/95189376

Because Function lets us customize how gradients are computed, we use it to implement the GRL. (For contrast, a minimal nn.Module sketch follows this list.)
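Before the GRL itself, here is a minimal sketch of the first approach (the name ScaleLayer is mine, purely illustrative, and is not used by the GRL code below):

import torch
import torch.nn as nn

# A plain nn.Module custom layer: only __init__ and forward are defined,
# and autograd derives the backward pass automatically.
class ScaleLayer(nn.Module):
    def __init__(self, scale: float = 2.0):
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale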
First, some imports and seeding, unrelated to the GRL itself, to make the test below reproducible:
from typing import Any, Optional, Tuple
from torch.autograd import Function
import torch.nn as nn
import torch
import torch.optim as optim
import torch.nn.functional as F
import random
import numpy
# fix all RNG seeds so the printed gradients below are reproducible
random.seed(0)
torch.manual_seed(0)
numpy.random.seed(0)
The Function-based GradientReverseFunction:
class GradientReverseFunction(Function):
    """Custom differentiation rule: identity on the forward pass,
    negated (and scaled by coeff) gradient on the backward pass."""
    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff     # stash coeff for the backward pass
        output = input * 1.0  # identity; * 1.0 avoids returning the input tensor itself
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # flip the gradient's sign and scale it by coeff; the trailing None
        # corresponds to the coeff argument, which needs no gradient
        return grad_output.neg() * ctx.coeff, None
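A quick sanity check (my own snippet, not from the original post): the forward pass is the identity, and the incoming gradient comes back negated and scaled by coeff:

x = torch.ones(3, requires_grad=True)
y = GradientReverseFunction.apply(x, 0.5)  # forward: y == x
y.sum().backward()                         # upstream gradient is all ones
print(x.grad)                              # tensor([-0.5000, -0.5000, -0.5000])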
class NormalClassifier(nn.Module):
    def __init__(self, num_features, num_classes, GRL=None):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        if GRL:
            self.grl = GRL()

    def forward(self, x):
        if getattr(self, 'grl', None) is not None:
            x = GradientReverseFunction.apply(x)  # note: the Function is applied directly here
        return self.linear(x)
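In this first version, any truthy GRL argument simply switches the apply branch on; the instantiated self.grl is never called. A usage sketch of mine (nn.Identity is just a truthy placeholder class):

clf = NormalClassifier(3, 6, GRL=nn.Identity)  # placeholder class; only its truthiness matters here
out = clf(torch.rand(4, 3))                    # forward values are unchanged; only gradients flip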
If calling apply directly feels unnatural, the Function can be wrapped into a layer.
Wrapping GradientReverseFunction into a GradientReverseLayer (here named GRL):
class GRL(nn.Module):
    def __init__(self):
        super(GRL, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)


class NormalClassifier(nn.Module):
    def __init__(self, num_features, num_classes, GRL=None):
        super().__init__()
        self.linear = nn.Linear(num_features, num_classes)
        if GRL:
            self.grl = GRL()

    def forward(self, x):
        if getattr(self, 'grl', None) is not None:
            x = self.grl(x)  # note: the wrapped layer is called here
        return self.linear(x)
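In DANN the reversal strength is usually not held at 1 but annealed from 0 towards 1 as training progresses. A hedged sketch of such a schedule on top of the same Function (the class name WarmStartGRL and the constants are illustrative, loosely following the 2/(1+exp(-alpha*p)) - 1 schedule from the DANN paper):

import math

class WarmStartGRL(nn.Module):
    """Illustrative GRL whose coefficient ramps from lo to hi over max_iters forward calls."""
    def __init__(self, alpha: float = 10.0, lo: float = 0.0, hi: float = 1.0, max_iters: int = 1000):
        super().__init__()
        self.alpha, self.lo, self.hi, self.max_iters = alpha, lo, hi, max_iters
        self.iter_num = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = min(self.iter_num / self.max_iters, 1.0)   # training progress in [0, 1]
        coeff = 2.0 * (self.hi - self.lo) / (1.0 + math.exp(-self.alpha * p)) - (self.hi - self.lo) + self.lo
        self.iter_num += 1
        return GradientReverseFunction.apply(x, coeff)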
Test code:
if __name__ == '__main__':
    net1 = NormalClassifier(3, 6)
    net2 = NormalClassifier(6, 10, GRL=None)  # without the reversal layer
    # net2 = NormalClassifier(6, 10, GRL=GRL)  # with the reversal layer
    net3 = NormalClassifier(10, 2)
    data = torch.rand((4, 3))
    label = torch.ones((4), dtype=torch.long)
    out = net3(net2(net1(data)))
    loss = F.cross_entropy(out, label)
    loss.backward()
    print('net1.linear.weight.grad', net1.linear.weight.grad)
    print('net2.linear.weight.grad', net2.linear.weight.grad)
    print('net3.linear.weight.grad', net3.linear.weight.grad)
Results:
# 1. Without the GRL
net1.linear.weight.grad tensor([[-0.0027, -0.0044, -0.0026],
[-0.0420, -0.0675, -0.0400],
[-0.0030, -0.0048, -0.0029],
[ 0.0035, 0.0056, 0.0033],
[-0.0336, -0.0540, -0.0320],
[-0.0454, -0.0729, -0.0432]])
net2.linear.weight.grad tensor([[ 0.0027, 0.0034, -0.0032, -0.0028, 0.0044, -0.0049],
[ 0.0452, 0.0577, -0.0544, -0.0473, 0.0747, -0.0830],
[-0.0897, -0.1146, 0.1081, 0.0939, -0.1483, 0.1647],
[-0.0702, -0.0897, 0.0846, 0.0735, -0.1161, 0.1290],
[ 0.0519, 0.0663, -0.0626, -0.0543, 0.0859, -0.0954],
[ 0.0520, 0.0664, -0.0627, -0.0544, 0.0860, -0.0955],
[-0.0967, -0.1235, 0.1166, 0.1012, -0.1599, 0.1776],
[-0.0058, -0.0074, 0.0069, 0.0060, -0.0095, 0.0106],
[-0.0124, -0.0158, 0.0149, 0.0129, -0.0204, 0.0227],
[ 0.0830, 0.1060, -0.1000, -0.0869, 0.1373, -0.1525]])
net3.linear.weight.grad tensor([[ 0.1127, -0.2764, -0.0864, -0.1450, 0.2694, -0.1738, -0.1415, 0.3108,
0.0458, -0.1464],
[-0.1127, 0.2764, 0.0864, 0.1450, -0.2694, 0.1738, 0.1415, -0.3108,
-0.0458, 0.1464]])
# 2. With the GRL
net1.linear.weight.grad tensor([[ 0.0027, 0.0044, 0.0026],
[ 0.0420, 0.0675, 0.0400],
[ 0.0030, 0.0048, 0.0029],
[-0.0035, -0.0056, -0.0033],
[ 0.0336, 0.0540, 0.0320],
[ 0.0454, 0.0729, 0.0432]])
net2.linear.weight.grad tensor([[ 0.0027, 0.0034, -0.0032, -0.0028, 0.0044, -0.0049],
[ 0.0452, 0.0577, -0.0544, -0.0473, 0.0747, -0.0830],
[-0.0897, -0.1146, 0.1081, 0.0939, -0.1483, 0.1647],
[-0.0702, -0.0897, 0.0846, 0.0735, -0.1161, 0.1290],
[ 0.0519, 0.0663, -0.0626, -0.0543, 0.0859, -0.0954],
[ 0.0520, 0.0664, -0.0627, -0.0544, 0.0860, -0.0955],
[-0.0967, -0.1235, 0.1166, 0.1012, -0.1599, 0.1776],
[-0.0058, -0.0074, 0.0069, 0.0060, -0.0095, 0.0106],
[-0.0124, -0.0158, 0.0149, 0.0129, -0.0204, 0.0227],
[ 0.0830, 0.1060, -0.1000, -0.0869, 0.1373, -0.1525]])
net3.linear.weight.grad tensor([[ 0.1127, -0.2764, -0.0864, -0.1450, 0.2694, -0.1738, -0.1415, 0.3108,
0.0458, -0.1464],
[-0.1127, 0.2764, 0.0864, 0.1450, -0.2694, 0.1738, 0.1415, -0.3108,
-0.0458, 0.1464]])
Analysis:

Forward data flow through the network (with the GRL enabled inside net2):
net1 --> GRL --> net2 --> net3
Backward gradient flow:
net3 --> net2 --> GRL --> net1
Comparing the two outputs: net2's and net3's gradients are identical in both runs, while net1's gradients are exactly sign-flipped. The reversal affects only what lies upstream of the GRL.
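Finally, a hedged sketch of where the GRL sits in a DANN-style setup, reusing the classes above (module sizes are illustrative): the domain head sees the shared features through the GRL, so minimizing the domain loss pushes the feature extractor to confuse the domain classifier.

features = NormalClassifier(3, 6)              # shared feature extractor
label_head = NormalClassifier(6, 10)           # ordinary task head
domain_head = NormalClassifier(6, 2, GRL=GRL)  # adversarial head behind the GRL

x = torch.rand(4, 3)
f = features(x)
loss = F.cross_entropy(label_head(f), torch.ones(4, dtype=torch.long)) \
     + F.cross_entropy(domain_head(f), torch.zeros(4, dtype=torch.long))
loss.backward()  # features receives the task gradient plus the reversed domain gradient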