梯度翻转层GRL

参考梯度翻转层GRL - 云+社区 - 腾讯云

编码器和领域分类器的训练目标是相反的,我们可以使用对抗网络(Adversarial Networks)的模式来进行训练。而另一种更加简单的方法就是梯度反转了。

梯度反转层

我们来看下图。模型的输入经过编码器得到特征向量,随后 被送到两个网络中:(1) 标记分类器和 (2) 领域分类器 。标记分类器输出数据标记 ,而领域分类器则预测特征向量的来源的领域

梯度翻转层GRL_第1张图片

在上面,编码器 和领域分类器的训练目标是对抗的,因此文章在二者之间添加了一个梯度反转层(gradient reversal layer, GRL)。

众所周知,反向传播是指将损失(预测值和真实值的差距)逐层向后传递,然后每层网络都会根据传回来的误差计算梯度,进而更新本层网络的参数。而GRL所做的就是,就是将传到本层的误差乘以一个负数( ),这样就会使得GRL前后的网络其训练目标相反,以实现对抗的效果。

下面是在pytorch实现的代码。

class grl_func(torch.autograd.Function):
    def __init__(self):
        super(grl_func, self).__init__()

    @ staticmethod
    def forward(ctx, x, lambda_):
        ctx.save_for_backward(lambda_)
        return x.view_as(x)

    @ staticmethod
    def backward(ctx, grad_output):
        lambda_, = ctx.saved_variables
        grad_input = grad_output.clone()
        return - lambda_ * grad_input, None


class GRL(nn.Module):
    def __init__(self, lambda_=0.):
        super(GRL, self).__init__()
        self.lambda_ = torch.tensor(lambda_)

    def set_lambda(self, lambda_):
        self.lambda_ = torch.tensor(lambda_)

    def forward(self, x):
        return grl_func.apply(x, self.lambda_)

需要注意的是,并不是一个常数,而是由0变为1,即

                          梯度翻转层GRL_第2张图片

其中,是一个超参数,文章中设为10;随着训练的进行由0变为1,表示当前的训练步数/总的训练步数。上面的式子意味着一开始时,,领域分类损失不会回传到编码器网络中,只有领域分类器得到训练;随着训练的进行,逐渐增加,编码器得到训练,并开始逐步生成可以混淆领域分类器的特征。

梯度翻转层GRL_第3张图片

你可能感兴趣的:(机器学习理论,计算机视觉,深度学习,机器学习)