pytorch | 记录一次register_hook不起作用以及为什么

0. 问题描述

  1. register_hook用于给某个tensor注册hooks,
    • 这个函数需要传入一个钩子函数,而且钩子函数的input是loss.backward()执行后的grad(不能获取weight值)
  2. 笔者这个时候loss不收敛,debug发现梯度为0,因此通过加钩子,试图发现在传播时哪里出了问题。因此发现了register_hook并不是100%能work,而且不是100%的可以打印出grad

1. 探究过程

  1. 钩子函数:
def save_grad(name):
    print("****")		# 这行可以说明有没有执行钩子函数
    def hook(grad):
        print(f"name={name}, grad={grad}")
    return hook
  1. 注册过程
# U_head是loss function中的一个中间tensor,需要计算梯度
U_head.register_hook(save_grad("U_head"))

2. 不work的原因

  1. 来自Stack Overflow的建议:

    • ‘register_hook’ won’t only in two cases:
      • It was registered on a Tensor for which the gradients was never computed./ 梯度没计算
      • The register_hook function is some part of your code that did not run during the forward. / 正向传播没执行这条语句
  2. 因此结合了自己的代码,修正了如下几个bug,然后就work了

    • 最关键的修改:
        for t in idx:
        # work了
       	 	U_tail= torch.cat([U_head, f_Equi_ts(t, T,v, d, b, alpha, labda,device)])
    	# 这样写不work,莫非属于原地修改?
    	# U_tail = torch.tensor([f_Equi_ts(t, T,v, d, b, alpha, labda,device) for t in idx],requires_grad=True)
    
    • 对于 需要loss function的代码重写,尽可能简单,易读性高,先不追求效率

你可能感兴趣的:(pytorch,神经网络,python,pytorch,python,深度学习)