Building on the previous post, this article explains the usage and purpose of every function in the IRM code.
import torch
from torch.autograd import grad

def compute_penalty(losses, dummy_w):
    # Split the per-sample losses into two halves (even and odd indices)
    # and take the gradient of each half's mean with respect to dummy_w.
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    # The IRMv1 penalty is the product of the two gradients; create_graph=True
    # keeps it differentiable so it can be backpropagated through phi.
    return (g1 * g2).sum()

def example_1(n=10000, d=2, env=1):
    # Synthetic data from Example 1 of the IRM paper: x causes y,
    # while z is an effect of y (a spurious correlate of the target).
    x = torch.randn(n, d) * env
    y = x + torch.randn(n, d) * env
    z = y + torch.randn(n, d)
    # Inputs are the concatenation [x, z] with shape [n, 4];
    # the target y.sum(1, keepdim=True) has shape [n, 1].
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)

phi = torch.nn.Parameter(torch.ones(4, 1))         # featurizer Phi, the only trained parameter
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))  # fixed "dummy" classifier omega = 1.0 (IRMv1)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")           # keep per-sample losses unreduced

# Two training environments that differ only in their noise scale.
environments = [example_1(env=0.1), example_1(env=1.0)]

for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        p = torch.randperm(len(x_e))                   # shuffle this environment's samples
        error_e = mse(x_e[p] @ phi * dummy_w, y_e[p])  # @ is matrix multiplication (torch.matmul)
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()
    opt.zero_grad()
    (1e-5 * error + penalty).backward()  # the penalty term dominates the objective
    opt.step()
    if iteration % 1000 == 0:
        print(phi)
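The core trick in compute_penalty is torch.autograd.grad with create_graph=True: the returned gradient is itself part of the computation graph, so the penalty built from it can be backpropagated through Φ. A minimal sketch of that mechanism, using a made-up toy loss:

import torch
from torch.autograd import grad

w = torch.nn.Parameter(torch.tensor([2.0]))
loss = (3.0 * w).pow(2).mean()           # loss = 9 w^2
g = grad(loss, w, create_graph=True)[0]  # dloss/dw = 18 w -> tensor([36.])
(g * g).sum().backward()                 # second-order: d(g^2)/dw = 648 w
print(w.grad)                            # tensor([1296.])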
torch.nn.Parameter is a Tensor subclass used to define model parameters: when a Parameter is assigned as an attribute of a Module, it is automatically added to the module's parameter list (note: the canonical name in the current API is torch.nn.parameter.Parameter). It takes two arguments, data and requires_grad. data is the tensor holding the parameter's value, and requires_grad sets whether gradients should be computed for it; disabling it can save computation.
phi = torch.nn.Parameter(torch.ones(4, 1))
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
The code first defines two parameters. The first is Φ, with shape [4, 1]: the inputs X1 and X2 (x and z in the code) each have dimension 2, so the concatenated input has 4 features. The second is ω; according to IRMv1 the classifier is a fixed constant, so it is set to 1.0 here.
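As a quick illustration of the auto-registration behaviour described above, here is a minimal sketch (LinearPhi is a made-up module name):

import torch

class LinearPhi(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a Parameter as a module attribute registers it automatically.
        self.phi = torch.nn.Parameter(torch.ones(4, 1))

m = LinearPhi()
print([p.shape for p in m.parameters()])  # [torch.Size([4, 1])]

In this script, however, phi is not wrapped in a Module, which is why it is passed to the optimizer explicitly as [phi].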
torch.optim.SGD is PyTorch's built-in optimizer implementing stochastic gradient descent. Its constructor takes six parameters: params, lr, momentum, dampening, weight_decay, and nesterov.
opt = torch.optim.SGD([phi], lr=1e-3)
This line creates an SGD optimizer that updates the single parameter Φ with a learning rate of 1e-3 (0.001).
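Since neither momentum nor weight decay is used here, opt.step() reduces to the plain update Φ ← Φ − lr · ∇Φ. A minimal sketch of a single update, with values chosen for illustration:

import torch

phi = torch.nn.Parameter(torch.ones(4, 1))
opt = torch.optim.SGD([phi], lr=1e-3)
loss = (phi ** 2).sum()
opt.zero_grad()
loss.backward()            # phi.grad = 2 * phi
opt.step()                 # each entry: 1 - 1e-3 * 2 = 0.998
print(phi[0, 0].item())    # 0.998...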
torch.nn.MSELoss is PyTorch's mean squared error loss; its reduction argument ("none", "mean", or "sum") controls how the per-element losses are aggregated.
mse = torch.nn.MSELoss(reduction="none")
When reduction is "none", the return value is a tensor of per-element squared errors rather than a single scalar.
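A small comparison of the two modes on toy tensors; the unreduced form is what compute_penalty relies on when it slices losses[0::2] and losses[1::2]:

import torch

pred = torch.zeros(4, 1)
target = torch.ones(4, 1)
print(torch.nn.MSELoss(reduction="none")(pred, target).shape)  # torch.Size([4, 1])
print(torch.nn.MSELoss(reduction="mean")(pred, target))        # tensor(1.)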
The data are generated as in Example 1 of the paper: torch.randn(n, d) draws an n × d tensor from a standard normal distribution, and multiplying by env scales the noise with the environment.
def example_1(n=10000, d=2, env=1):
    x = torch.randn(n, d) * env
    y = x + torch.randn(n, d) * env
    z = y + torch.randn(n, d)
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)  # shapes: [n, 4] and [n, 1]
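A quick shape check with a small n (illustrative values) confirms the returned shapes: the inputs concatenate the two 2-dimensional blocks x and z, and the target collapses y's two columns into one:

inputs, targets = example_1(n=5, d=2, env=1.0)
print(inputs.shape)   # torch.Size([5, 4])
print(targets.shape)  # torch.Size([5, 1])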