利用BCELoss解决Multi-label问题

class torch.nn.BCELoss(weight=None, size_average=True)[source]

计算 target 与 output 之间的二进制交叉熵。 l o s s ( o , t ) = − 1 n ∑ i ( t [ i ] l o g ( o [ i ] ) + ( 1 − t [ i ] ) l o g ( 1 − o [ i ] ) ) loss(o,t)=-\frac{1}{n}\sum_i(t[i] log(o[i])+(1-t[i]) log(1-o[i])) loss(o,t)=n1i(t[i]log(o[i])+(1t[i])log(1o[i])) 如果weight被指定 : l o s s ( o , t ) = − 1 n ∑ i w e i g h t s [ i ] ( t [ i ] l o g ( o [ i ] ) + ( 1 − t [ i ] ) ∗ l o g ( 1 − o [ i ] ) ) loss(o,t)=-\frac{1}{n}\sum_iweights[i] (t[i] log(o[i])+(1-t[i])* log(1-o[i])) loss(o,t)=n1iweights[i](t[i]log(o[i])+(1t[i])log(1o[i]))

这个用于计算 auto-encoder 的 reconstruction error。注意 0<=target[i]<=1。

默认情况下,loss会基于element平均,如果size_average=False的话,loss会被累加。

train:

训练时候只需要利用one-hot编码对每一个example的lable进行组织,比如输入X具有4个属性,一共具有A,B,C和D这4个类别, X 1 X_1 X1同时属于A,D两类,则有如下符号表示形式:
X 1 = ( x 1 , x 2 , x 3 , x 4 , x 5 , x 6 , x 7 , x 8 , x 9 , x 10 ) , Y 1 = ( 1 , 0 , 0 , 1 ) X_1=(x1,x2,x3,x4,x5,x6,x7,x8,x9,x10),\\ Y_1=(1,0,0,1) X1=(x1,x2,x3,x4,x5,x6,x7,x8,x9,x10),Y1=(1,0,0,1)

validation:

由于每一个样本可能属于多个类别,因此不能直接根据argmax(output,1)去直接取概率最大的输出值作为其类别,而是应该设置一个阈值(默认0.5)来对output进行过滤,得到样本所属的multi-class。
以下为Pytorch实现多标签损失函数的样例代码:

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

model = nn.Linear(20, 5) # predict logits for 5 classes
x = torch.FloatTensor(1, 20)
y = torch.FloatTensor([[1., 0., 1., 0., 0.]]) # get classA and classC as active
#criterion = nn.BCEWithLogitsLoss()
criterion = nn.BCELoss(weight=None, size_average=True)#input:FloatTensor target:FloatTensor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
for epoch in range(2000):
    optimizer.zero_grad()
    var_x = torch.autograd.Variable(x)
    var_y = torch.autograd.Variable(y)
    output = model(var_x)
    output_sig = F.sigmoid(output)#注意对于BCELoss必须保证其输入的参数都位于0到1之间,无论是标签还是预测值,如果不想让预测值必须位于0到1之间则可以采用BCEWithLogitsLoss损失函数。
    loss = criterion(output_sig, var_y)
    loss.backward()
    optimizer.step()
    print('Loss:{}'.format(loss))

具体的BCELoss的使用方法可以参考如下博客:
https://blog.csdn.net/tmk_01/article/details/80844260?utm_source=blogxgwz0

你可能感兴趣的:(Pytorch学习)