RUAS代码debug

  • 最近训练Retinex-inspired Unrolling with Cooperative Prior Architecture Search for Low-light Image Enhancement的代码(https://github.com/KarelZhang/RUAS)的时候报错说梯度回传有问题,有变量被修改了,仔细一看源码,问题出在这里:
  • RUAS代码debug_第1张图片
  • model.py文件中,226行是报错位置,也就是说denoise_loss backward的时候,与denoise_loss有梯度回传关系的张量或者网络参数被修改了,导致torch检查的时候,denoise loss来自version 50的张量,回传到某个张量的时候,这个张量的版本已经变成version 51了,其实很好理解,因为每50次才会运行到一次226行,所以问题出在前面enhancement_loss的部分
  • 问题出在220行,220行先把enhance_net的参数更新掉了,而到226行梯度回传到enhance_net,此时enhance_net的版本已经不是产生denoise_loss的版本了。要改也很简单,因为从算法上,enhance_net和denoise_net其实并不是联合训练关系,梯度完全没有必要回传到enhance_net,只需要掐断就好了。那么梯度是怎么传到enhance_net的呢,其实是192行,enhance_net生成的u_list进一步传到了denoise_net,导致回传的时候通过u_list回传到了enhance_net,那么只需要在这里掐断就好了,加个detach:
    RUAS代码debug_第2张图片
  • 同时225行算loss的时候也用到了u_list,也掐断:
    RUAS代码debug_第3张图片
  • 改完就能正常训练了。注意这个改法是建立在明确知道denoise_net和enhance_net没有任何联合训练关系的前提,否则不建议随意掐断梯度。

你可能感兴趣的:(实用代码,计算机视觉,深度学习)