目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结。
类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()
类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()
①当为标准的二分类时,网络的输出节点为1
②当为非互斥的多分类时,分类个数即为网络的输出节点数
类型三:F.binary_cross_entropy()与torch.nn.BCELoss()
网络的输出节点可以为2,此时概率必须由softmax进行映射。
②当为非互斥的多分类时,分类个数即为网络的输出节点数,此时概率必须由sigmoid进行映射
类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()
- 网络的输出节点为2,表示real和fake(类别1和类别2)
类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()
- 由于这两个函数自带sigmoid函数,要想完成二分类,网络的输出节点个数必须设置为1
类型三:F.binary_cross_entropy()与torch.nn.BCELoss(),以下两种情况都可以使用:
- 当网络输出的节点为2时,一个节点为real另一个节点为fake,那么必然要采用softmax将logits映射为概率(两个节点的概率和为1),此时该函数输入为onehot label + softmax prob,计算出的交叉熵损失与类型一结算结果相同。
- 当网络的输出节点为1时,也就是后面我们要讲的GAN的交叉熵损失的实现,那么则需要使用sigmoid函数来进行映射。
这里我们以网络输出节点为2为例,由于类型二要求网络的输出节点为1,因此暂时不纳入讨论,主要讨论类型和类型三。测试代码如下:
(网络输出节点为1的二分类就是目前GAN的实现方式,该方式下类型一的函数不可用,只能采用类型二和类型三,后面将会详细讨论)
softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1],
[-1.587, -0.5907]])
classes = 2
label = torch.tensor([1, 1])
logits = torch.from_numpy(logits).float()
#F.cross_entropy
loss1 = F.cross_entropy(logits, label)
print(loss1)
#nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
#可以看到,loss1是等于loss2的
prob = softmax(logits) #计算概率
one_hot_label = one_hot(label, classes)
#F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob, one_hot_label) #输入概率和one-hot
print(loss3)
#torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob, one_hot_label)
print(loss4)
#同理,loss3是等于loss4的
#手动实现二分类的交叉熵损失
shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1)) #手动实现
print(shixian)
此时网络输出时多节点,每一个节点代表一个类别。
类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()
- 可以用于多分类的互斥任务,输入非onehot label + logit。但是不能用于多分类多标签任务。因为这两个函数中自带的softmax将网络的每一个节点都当作时互斥的独立节点,每个节点的概率和为1,因为概率最大的那个节点的类别会被当为最终的预测类别
类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()
- 不能用于多分类的互斥任务,只能用于多分类的非互斥任务
类型三:F.binary_cross_entropy()与torch.nn.BCELoss()
- 与类型二一样,不能用于多分类的互斥任务,只能用于多分类的非互斥任务。
这里我们首先讨论下类型一和类型三,为什么类型三不能用于多分类的互斥任务,只能用于多分类多标签的分类任务?我们来看一段代码,这里有三个类别,两个样本。
softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1, 0.2],
[-1.587, -0.5907, 0.3]])
classes = 3
label = torch.tensor([1, 2])
logits = torch.from_numpy(logits).float()
### F.cross_entropy
loss1 = F.cross_entropy(logits, label)
print(loss1)
### nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
##loss1 = loss2
上面是采用类型一的两个函数计算而来,loss1 = loss2 = 0.9833
然后我们用类型三的函数来实现,同样将logit通过softmax映射为概率,运行后的结果可以看loss3 =loss4 = 0.5649,不等于类型一的函数的结果的。
prob_softmax = softmax(logits) #计算概率
one_hot_label = one_hot(label, classes)
## F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #输入概率和one-hot
print(loss3)
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob_softmax, one_hot_label)
print(loss4)
最后我们再手动实现类型三的损失究竟是怎么得到的:
#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax))
print(shixian)
可以看出来,F.binary_cross_entropy()与torch.nn.BCELoss()是将网络的每个节点看作是一个二分类的节点来计算交叉熵损失的。
进一步来讨论下类型二和类型三的一致性,代码如下。由于类型二中函数自动将logit通过sigloid函数映射为概率,为了检验一致性性,我门也需要通过sigmoid计算类型三所需要的概率。
最后可以看到下面的输出均为0.6378
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits) #计算概率
##类型二
##F.binary_cross_entropy_with_logits
loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label)
print(loss5)
##torch.nn.BCEWithLogitsLoss()
BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
loss6 = BCEWithLogitsLoss(logits, one_hot_label)
print(loss6)
##类型三
##F.binary_cross_entropy
loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #输入概率和one-hot
print(loss7)
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss8 = adversarial_loss(prob_sig, one_hot_label)
print(loss8)
#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig))
print(shixian)
GAN中的判别器出的损失就是典型的最小化二分类的交叉熵损失。但是在实现上,与二分类网络不同。
正因为判别器的输出是一维,类型一的两个函数F.cross_entropy()与torch.nn.CrossEntropyLoss()是没有办法使用的,因为这两个函数要求输入是二维的,即分别在real和fake的logit。因此只能采用类型二或者类型三的函数。
很多GAN网络采用的二分类交叉熵损失函数如下:
#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y)
#类型三:
adversarial_loss_3 = torch.nn.BCELoss(p,y)
前面我们讲到,类型二和类型三的函数都是将每一个节点视为一个二分类的节点,因此对于每一个给节点,其具体的表达式可以写为:
#类型二:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
# 其中logit表示判断为real的logit
# y=1表示real
# y=0表示fake
#类型三:
torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p))
# 其中p表示判断为real的概率
# y=1表示real
# y=0表示fake
判别器输出维度为1,输出logit,有两个样本,都为fake图像
logits = np.array([1.2, -0.5])
logits = torch.from_numpy(logits).float()
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits) #计算概率
label = torch.tensor([1, 1]).float()
#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss()
loss_2 = adversarial_loss_2(logits, 1-label) #因为是fake,需要将y设置为0
print(loss_2)
#类型三:
adversarial_loss_3 = torch.nn.BCELoss()
loss_3 = adversarial_loss_3(prob_sig, 1-label) #因为是fake,需要将y设置为0
print(loss_3)
#输出均为0.9687
通过上述代码可以分析如下:
(1)当样本为fake时,网络输出其为real的logit:
(2)样本为real,网络输出其为real的logit:
GAN网络在更新判别器时,代码一般如下:
criterion = torch.nn.BCELoss()
real_out = D(real_img) # 将真实图片放入判别器中
d_loss_real = criterion(real_out, 1) # 真实样本的损失
fake_img = G(z) # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img) # 判别器判断假的图片,
d_loss_fake = criterion(fake_out, 0) # 生成样本的损失
d_loss = d_loss_real + d_loss_fake # 两个相加 就是标准的交叉熵损失
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
前面判别器处的损失是最小化交叉熵损失:
min - (ylog(p) + (1-y)log(1-p))
那么生成器与之相反就是最大化交叉熵损失:
max - (ylog(p) + (1-y)log(1-p))
因为真实样本于与生成器无关,因此可以转变为min log(1-p)
max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)
上述形式为饱和形式,转变为非饱和如下。
min -log(p)
可以看到上式子在形式上就是将fake图像当作real图像进行优化。
可以这么理解:生成器的作用的就是尽可能生成逼近与real的fake,由于判别器判断的结果p就是表示图像为real的概率,那么生成器就希望p越高越好。而在训练判别器时,判别器对real的优化就是让其p越高越好,即尽可能的区分real和fake。
因此在更新生成器时,fake处的损失与更新判别器在real处的损失在逻辑上是一致的。
criterion = torch.nn.BCELoss()
fake_img = G(z) # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img) # 判别器判断假的图片,
G_loss = criterion(fake_out, 1) # 假样本的损失
optimizer_G.zero_grad()
G_loss .backward()
optimizer_G.step()
在GAN网络中,由于输出网络只有一个节点,表示图像属于real的logit或者prob,因此一般使用类型二和类型三的损失函数。
两类函数的实现如下:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit))) torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))
因为上述实现:
- 在更新判别器时:real图像后面label为1,fake图像后面label为0。分别计算real和fake的损失相加。
- 在更新判别器时:与real图像无关,fake图像后面label为1,更新。