Pytorch 梯度反转层及测试

Pytorch 梯度反转层及测试

参考文献:
梯度反转

import torch
import torch.nn as nn
from torch.autograd.function import Function


class Grl_func(Function):
    def __init__(self):
        super(Grl_func, self).__init__()

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.save_for_backward(lambda_)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        lambda_, = ctx.saved_tensors
        grad_input = grad_output.clone()
        return -lambda_ * grad_input, None


class GRL(nn.Module):
    def __init__(self, lambda_=0.):
        super(GRL, self).__init__()
        self.lambda_ = torch.tensor(lambda_)

    def set_lambda(self, lambda_):
        self.lambda_ = torch.tensor(lambda_)

    def forward(self, x):
        return Grl_func.apply(x, self.lambda_)


# 首先建立一个全连接的子module,继承nn.Module
class Linear1(nn.Module):
    def __init__(self):
        super(Linear1, self).__init__()  # 调用nn.Module构造函数
        # 使用nn.Parameter来构造需要学习的参数
        self.w = nn.Parameter(torch.tensor([[1., 2., 3.], [1., 1., 1.]]))
        self.b = nn.Parameter(torch.tensor([1., 1., 1.]))

    # 在forward中实现向前传播过程
    def forward(self, x):
        x = x.matmul(self.w)  # 使用Tensor.matmul实现矩阵相乘
        y = x + self.b.expand_as(x)  # 使用Tensor.expand_as()来保证矩阵形状一致
        return y


# 首先建立一个全连接的子module,继承nn.Module
class Linear2(nn.Module):
    def __init__(self):
        super(Linear2, self).__init__()  # 调用nn.Module构造函数
        # 使用nn.Parameter来构造需要学习的参数
        self.w = nn.Parameter(torch.tensor([[1., 2., 3.], [1., 1., 1.], [1., 1., 1.]]))
        self.b = nn.Parameter(torch.tensor([1., 1., 1.]))

    # 在forward中实现向前传播过程
    def forward(self, x):
        x = x.matmul(self.w)  # 使用Tensor.matmul实现矩阵相乘
        y = x + self.b.expand_as(x)  # 使用Tensor.expand_as()来保证矩阵形状一致
        return y


# 实例化一个网络,并赋值全连接中的维数,最终输出二维代表了二分类
perception1 = Linear1()
perception2 = Linear2()
grl = GRL()
grl.set_lambda(1.0)
# 随机生成数据,注意这里的4代表了样本数为4,每个样本有两维
data = torch.tensor([[2., 1.], [1., 1.]])

output = perception1(data)
# output = grl(output)  # 是有效的
output = perception2(output)
output = grl(output)  # 是有效的

print(f'output:\n {output}\n')

output.sum().backward()

print(f'perception1.w.grad:\n {perception1.w.grad}\n')
print(f'perception1.b.grad:\n {perception1.b.grad}\n')

print(f'perception2.w.grad:\n {perception2.w.grad}\n')
print(f'perception2.b.grad:\n {perception2.b.grad}\n')

你可能感兴趣的:(重要,pytorch,深度学习)