[笔记] PyTorch 踩坑记录

目录

  1. 关于 torch.clamp()
  2. 关于 tensor.scatter_()
  3. 关于 tensor.detach() 和 tensor.data
  4. 关于 tensor.register_hook
  5. 关于 tensor.narrow()
  6. 关于 torch.utils.data.DataLoader()
  7. 关于 torch.nn.register_parameter() 与 torch.nn.register_buffer()
  8. 关于 DataLoader worker (pid xxx) is killed by signal: Killed.
  9. 关于 torch.nn.Module.train() 和 torch.nn.Module.eval()

1. 关于 torch.clamp()

在我尝试实现 Focal Loss 的过程中,遇到 Loss 有时候变成 nan 的情况,询问师兄后得知:如果要对数值进行 log 操作,最好先对其进行 clamp 操作,防止其中存在极小值,导致计算结果出现 nan

prob = self.softmax(pred.view(-1,self.class_num))
prob.clamp(min=0.0001,max=1.0) # 等价于 prob = torch.clamp(prob,min=0.0001,max=1.0)

(tensor.clamp_() 为 in-place 函数,与其对应的 out-place 版本的为 tensor.clamp())

官方文档:

[笔记] PyTorch 踩坑记录_第1张图片

2. 关于 tensor.scatter_()

在实现 Focal Loss 的过程中,为了实现针对多分类任务的 Focal Loss,需要将模型所生成的预测向量转为 one-hot 向量,查询了一番便了解到了 scatter_。scatter_ 将 src 中数据根据 index 中的索引按照 dim 的方向填进input中,详细解释可以阅读这篇博文:Pytorch scatter_ 理解轴的含义。

target_ = torch.zeros(target.size(0),self.class_num).cuda()
target_.scatter_(1, target.view(-1, 1).long(), 1.)

(tensor.scatter_() 为 in-place 函数,与其对应的 out-place 版本的为 tensor.scatter())

官方文档:

[笔记] PyTorch 踩坑记录_第2张图片

3. 关于 tensor.detach()

关于 tensor.detach() 和 tensor,data 的区别,网上已经有许多讲解了,可以查看这些网页:Differences between .data and .detach,pytorch .detach() .detach_() 和 .data用于切断反向传播,PyTorch中 tensor.detach() 和 tensor.data 的区别。

但是,我觉得这篇博文讲的更直白一些,Pytorch入门学习(九)—detach()的作用(从GAN代码分析)。最近刚好在做 Domain Adaptation 相关方面的学习,在阅读 Conditional Adversarial Domain Adaptation 的代码时,发现作者也利用 tensor.detach() 进行梯度截断。该作者还在 issue 中解释了该操作的目的:

[笔记] PyTorch 踩坑记录_第3张图片

官方文档:

[笔记] PyTorch 踩坑记录_第4张图片

4. 关于 tensor.register_hook()

在 PyTorch 中,Hook 机制除了 tensor.register_hook() 外,还有 nn.Module.register_forward_hook() 和 nn.Module.register_backward_hook
(),相关的详细介绍可以阅读这些博文:pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解,pytorch中的钩子(Hook)有何作用?。

在 Conditional Adversarial Domain Adaptation 的代码中,作者通过使用 tensor.register_hook(),使得梯度回传时,Domain Discriminator 和 ResNet 所得到的梯度值符号相反,以此达到对抗学习的目的。

[笔记] PyTorch 踩坑记录_第5张图片
def grl_hook(coeff):
    def fun1(grad):
        return -coeff*grad.clone()
    return fun1

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    return np.float(2.0 * (high - low) / (1.0 + np.exp(-alpha*iter_num / max_iter)) - (high - low) + low)

class AdversarialNetwork(nn.Module):
  def __init__(self, in_feature, hidden_size):
    super(AdversarialNetwork, self).__init__()

    self.ad_layer1 = nn.Linear(in_feature, hidden_size)
    self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
    self.ad_layer3 = nn.Linear(hidden_size, 1)

    self.relu1 = nn.ReLU()
    self.relu2 = nn.ReLU()

    self.dropout1 = nn.Dropout(0.5)
    self.dropout2 = nn.Dropout(0.5)

    self.sigmoid = nn.Sigmoid()

    self.apply(init_weights)
    self.iter_num = 0
    self.alpha = 10
    self.low = 0.0
    self.high = 1.0
    self.max_iter = 10000.0

  def forward(self, x):
    if self.training:
        self.iter_num += 1
    coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
    x = x * 1.0
    x.register_hook(grl_hook(coeff))
    x = self.ad_layer1(x)
    x = self.relu1(x)
    x = self.dropout1(x)
    x = self.ad_layer2(x)
    x = self.relu2(x)
    x = self.dropout2(x)
    y = self.ad_layer3(x)
    y = self.sigmoid(y)

    return y

  def output_num(self):
    return 1

  def get_parameters(self):
    return [{"params":self.parameters(), "lr_mult":10, 'decay_mult':2}]

官方文档:

[笔记] PyTorch 踩坑记录_第6张图片

5. 关于 tensor.narrow()

tensor.narrow() 沿着所输入 tensor 的维度 dimension,从索引 start 开始,提取共计 length 个数据。

output.narrow(0, 0, data_source.size(0)) # 等价于 output = torch.narrow(0, 0, data_source.size(0))

官方文档:

[笔记] PyTorch 踩坑记录_第7张图片

6. 关于 torch.utils.data.DataLoader()

常见的使用 DataLoader 的代码:

data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)
for epoch in range(EPOCH):
    for step, (data, label) in enumerate(data_loader):   # gives batch data
         ......

但是在 Domain Adversarial 的代码中,经常需要对两个 DataLoader 同时进行迭代,我们可以这样写:


    len_source = len(train_source_dataloader)
    len_target = len(train_target_dataloader)
    if len_source > len_target:
        num_iter = len_source
    else:
        num_iter = len_target

    end = time.time()
    for batch_idx in range(num_iter):
        if batch_idx % len_source == 0:
            iter_source = iter(train_source_dataloader)    
        if batch_idx % len_target == 0:
            iter_target = iter(train_target_dataloader)
        
        data_source, label_source = iter_source.next()
        data_source, label_source = data_source.cuda(), label_source.cuda()
        data_target, label_target = iter_target.next()
        data_target = data_target.cuda()

其中,有这样一条语句:len_source = len(train_source_dataloader),这得到的是该 data_loader 的迭代次数,即len(data_loader) = math.ceil(len(data_loader.dataset)/batch_size)),详细解释可以查看About the relation between batch_size and length of data_loader。

7. 关于 torch.nn.register_parameter() 与 torch.nn.register_buffer()

torch.nn.register_parameter() 用于注册 Parameter 实例到当前 Module 中(一般可以用 torch.nn.Parameter() 代替),torch.nn.register_buffer() 用于注册 Buffer 实例到当前 Module 中。此外,Module 中的 buffers() 函数会返回当前 Module 中所注册的所有 Buffer 的迭代器,而 parameters() 函数会返回当前 Module 中所注册的所有 Parameter 的迭代器(所以优化器不会计算 Buffer 的梯度,自然不会对其更新)。此外,Module 中的 state_dict() 会返回包含当前 Module 中所注册的所有 Parameter 和 Buffer(所以模型中未注册成 Parameter 或 Buffer 的参数无法被保存)。(详细解释,可以看 What is the difference between ‘register_buffer’ and ‘register_parameter’ of ‘nn.Module’)

官方文档:

[笔记] PyTorch 踩坑记录_第8张图片 [笔记] PyTorch 踩坑记录_第9张图片

8.关于 DataLoader worker (pid xxx) is killed by signal: Killed

这条报错出现在我设置 DataLoader 使用多进程(即 num_worker > 0)且程序中出现 OOM(Out Of Memory,注意指的是内存,并非显存) 的时候,光看这条报错的信息并不能知道是由于 OOM 所导致的,我也是上网搜索才知道的。

9. 关于 torch.nn.Module.train() 和 torch.nn.Module.eval()

大家都知道在测试模型的时候需要调用 torch.nn.Module.eval(),不然无法固定模型中 Dropout 和 BN 模块,大家也知道 torch.nn.Module.eval() 等价于torch.nn.Module.train(mode=False),但是很少有人去想这两个函数到底是如何让 Module 设置为训练/测试模式的。

最近在做实验的时候,需要仿照 BN 的实现构造一个取平均值的模型,于是跑去看了看 torch.nn.Module 的源码。 torch.nn.Module 中有个属性 training ,该属性决定 Module 处于何种模式中。换句话说,train()/eval() 就是通过更改 training 的值来改变 Module 的状态的。

官方文档:

[笔记] PyTorch 踩坑记录_第10张图片

参考资料:

  • Pytorch scatter_ 理解轴的含义
  • Differences between .data and .detach
  • pytorch .detach() .detach_() 和 .data用于切断反向传播
  • PyTorch中 tensor.detach() 和 tensor.data 的区别
  • Pytorch入门学习(九)—detach()的作用(从GAN代码分析)
  • pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
  • pytorch中的钩子(Hook)有何作用?
  • Pytorch:RuntimeError: DataLoader worker (pid 27) is killed by signal: Killed. Details are lost due
  • What is the difference between ‘register_buffer’ and ‘register_parameter’ of ‘nn.Module’

如果你看到了这篇文章的最后,并且觉得有帮助的话,麻烦你花几秒钟时间点个赞,或者受累在评论中指出我的错误。谢谢!

作者信息:
知乎:没头脑
CSDN:Code_Mart
Github:Tao Pu

你可能感兴趣的:(PyTorch)