【pytorch】详解vanilla(BCEWithLogitsLoss )、lsgan(nn.MSELoss)

有如下代码块

       self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None

vanilla 即为 BCEWithLogitsLoss 

参见:https://blog.csdn.net/qq_22210253/article/details/85222093


lsgan 在pytorch中为 nn.MSELoss(),均方误差

公式:

MSELOSS = \frac{1}{2n}\sum||y_{n}-x_{n}||^{2}

 其中y为target,x为模型输出值

示例:

import torch
import torch.nn as nn

output = torch.rand(2,2)
print(output)

 tensor([[0.1234, 0.8351],
        [0.9274, 0.8286]])

target = torch.FloatTensor([[0,1],[1,0]])
print(target)

tensor([[0., 1.],
        [1., 0.]])

利用nn.MSELoss()计算损失:

crit = nn.MSELoss()
cost = crit(input,target)

输出结果为:

利用公式手工验证计算:(注意除的数为2n,即为4)

MSELoss_handle = ((0-0.1234)*(0-0.1234) + (1-0.8351)*(1-0.8351))/4 + ((1-0.9274)*(1-0.9274) + (0-0.8286)*(0-0.8286))/4

输出结果为:

四舍五入之后结果一致!


wgangp 参见

https://blog.csdn.net/weixin_37993251/article/details/87120269

你可能感兴趣的:(python,深度学习)