【GAN网络】tensorflow和pytorch实现损失函数

tensorflow实现 

 主要注意以下两点:
1、log形式的损失函数输入的是经过判别器的概率D(x),tf.nn.sigmoid_cross_entropy_with_logits输入的是logits
2、log形式前面要添加负号,tf.nn.sigmoid_cross_entropy_with_logits不用。

import tensorflow as tf
#batch_size = 3,真实数据
real_logits=tf.constant([[1.25],
                   [2.5],
                   [-1.7]]) #GAN的话是两分类,因此最后只有一个节点,经过D映射后表示为真的概率   
real_prob=tf.nn.sigmoid(real_logits)     #真实数据经过D后被判别为真的概率
read_labels=tf.ones_like(real_logits)  #稠密标签,这里为1,该数据的是真实数据

real_D_loss = -tf.reduce_mean(tf.log(real_prob))  #以log形式计算损失,真实数据的损失为-log(D(x))
real_d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, 
                                                                   labels=read_labels))  #以tensorflow自带的形式计算,自动带负号
#batch_size = 3,生成数据
fake_logits=tf.constant([[0.22],
                   [10.8],
                   [-5]]) 
fake_prob=tf.nn.sigmoid(fake_logits)    
fake_labels=tf.zeros_like(fake_logits)   #生成数据的标签为0
fake_D_loss = -tf.reduce_mean(tf.log(1-fake_prob))  #以log形式计算损失,生成数据的损失为-log(1-D(x))
fake_d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, 
                                                                   labels=fake_labels))  #以tensorflow自带的形式计算,自动带负号

Total_loss = real_D_loss + fake_D_loss #log形式
total_loss = real_d_loss + fake_d_loss #tensorflow

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    D_losss, d_losss=sess.run([real_D_loss,real_d_loss])
    print(D_losss,d_losss)  #0.73286825 0.73286825 两者是一样的
    
    Total_losss, total_losss=sess.run([Total_loss,total_loss])
    print( Total_losss, total_losss) #两者也是一样的

pytorch实现

 主要注意以下两点:
1、nn.BCELoss函数计算的是每一个样本的损失,总损失还要对整个batch求平均,因此后面有mean()
2、前面不用加负号

import torch
import torch.nn as nn

m = nn.Sigmoid()
loss = nn.BCELoss(size_average=False, reduce=False)
#真实数据
real_logits = torch.tensor([[1.25],
                       [2.5],
                       [-1.7]])
real_prob = m(real_logits)
real_label = torch.ones(3,1)  
real_loss = loss(real_prob, real_label).mean()

#虚假数据
fake_logits=torch.tensor([[0.22],
                   [10.8],
                   [-5]]) 
fake_label = torch.zeros(3,1) 
fake_prob = m(fake_logits)
fake_loss = loss(fake_prob, fake_label).mean()

total_loss = real_loss + fake_loss


print("计算real_loss的结果:")
print(real_loss)
print("计算fake_loss的结果:")
print(fake_loss)
print("计算总损失:")
print(total_loss)

 

你可能感兴趣的:(tensorflow,GAN,pytorch,tensorflow,GAN,pytorch)