Gan训练思想

1、两组数据, 两个网络 D(鉴别器网络) G(生成器网络) , opti_D opti_G ,
训练 鉴别器网络,
(1)real data 输入Dmodel 得到预测值计算损失, lossDr , 预测值越大越好,
(2)Gmodel 生成的 fake data , 输入Dmodel 得到预测值计算损失 ,lossDf , 预测值越小越好 ,Gmodel.detach() 这里有一个梯度截断,不更新Gmodel(虽然用到了他)
更新 opti_D
目的是分开两种数据,真的就是真的,假的就是假的 看损失函数怎么写的,两者损失都要小

训练 生成器网络
(1)Gmodel 生成的 fake data , 输入Dmodel 得到预测值计算损失 ,lossDf , 预测值越大越好
(这里和上面的区别,产生对抗, 目的是为了生成逼近真实数据的生成数据)
这里反向梯度还是从 Dmodel 预测损失传过来,
更新的时候 opti_G

对应如下代码

Gan训练思想_第1张图片

Gan训练思想_第2张图片

你可能感兴趣的:(Gan训练思想)