GAN 中关于辨别器 detach()函数的作用

detach()函数的作用相信大家已经在网上搜过了,这里再简单叙述一下,照搬一下,

即返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad。简单来说detach就是截断反向传播的梯度流。

之前测试了一篇mnist-GAN的代码,代码很简单,最近要用到他。但是如下所示,辨别器D中有个detach()函数,

 D_optimizer.zero_grad() 
        fake = G(z)  
        
        d_fake_res = D(fake.detach()) 
        d_fake_loss = BCE_loss(d_fake_res,torch.zeros_like(d_fake_res)) 
        d_real_res = D(x)  # shape : torch.Size([128,1])
        d_real_loss = BCE_loss(d_real_res,torch.ones_like(d_real_res)) 
        d_loss = (d_fake_loss + d_real_loss)/2

        mean_D_loss += d_loss.item() / display_step
        d_loss.backward()
        D_optimizer.step() 

在结合上网查资料之后我所理解的如下:

用到detach()函数的原因是因为辨别器的损失函数loss是由两部分组成的,其中的一部分还与生成器有关。所以当执行d_loss.backward() 反向传播的时候,会把所有相关权重参数的梯度都给计算到,虽然辨别器的优化器参数给定的只是它自己的网络参数本身,但是这并不妨碍反向传播要计算有关反向传播计算图中的所有权重参数。这样会造成资源的浪费。所以在训练D时要用detach来截断反向传播流,使反向传播只计算完有关辨别器的所有权重参数就不再继续往下进行了。(注意,反向传播时辨别器是在第一位置,然后才传到生成器,即辨别器的输出→ 辨别器→生成器)。

然后训练生成器的时候,由于正向过程中生成器在辨别器之前,反向传播需要从辨别器的输出往回传,所以反向传播时的顺序是  辨别器的输出→ 辨别器→生成器,所以训练生成器时不需要再阻断反向传播的梯度流了。

希望能狗帮助到你,如果我的叙述有误,请留言评论探讨,谢谢帮我改正。加油(ง •_•)ง

你可能感兴趣的:(深度学习,神经网络,机器学习,python)