BCELoss expects inputs that have already been passed through a sigmoid.
BCEWithLogitsLoss fuses the Sigmoid and BCELoss steps into one.
import torch
import numpy as np

input = torch.tensor(np.arange(3) / 3).reshape(3, 1)   # raw scores (logits), float64
input2 = torch.nn.Sigmoid()(input)                      # probabilities in (0, 1)
loss = torch.nn.BCELoss()
target = torch.tensor([[0], [1], [1]]).to(torch.double)
print(loss(input2, target))                             # BCELoss on the sigmoid output
loss = torch.nn.BCEWithLogitsLoss()
print(loss(input, target))                              # same result, straight from the logits
# Output
#tensor(0.5493, dtype=torch.float64)
#tensor(0.5493, dtype=torch.float64)
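Fusing the two steps matters for numerical stability, not just convenience: BCEWithLogitsLoss uses the log-sum-exp trick internally, so it stays accurate where an explicit sigmoid saturates. A minimal sketch (relying on the fact that PyTorch's BCELoss clamps its log terms at -100):

big = torch.tensor([[20.0]])
t = torch.tensor([[0.0]])
print(torch.nn.BCELoss()(torch.sigmoid(big), t))   # sigmoid(20) rounds to 1.0 in float32, log(0) clamped -> tensor(100.)
print(torch.nn.BCEWithLogitsLoss()(big, t))        # computed from the logit directly -> tensor(20.)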
As above, CrossEntropyLoss fuses the Softmax–Log–NLLLoss pipeline into a single step. Note the extra log layer here: NLLLoss expects log-probabilities, not probabilities.
input = torch.randn(3, 3)
input2 = torch.log(torch.nn.Softmax(dim=1)(input))   # log-probabilities
loss = torch.nn.NLLLoss()
target = torch.tensor([0, 2, 1])
print(loss(input2, target))
loss = torch.nn.CrossEntropyLoss()
print(loss(input, target))                           # same value, straight from the logits
# Output (the input is random, so the value varies per run; the two losses always match)
#tensor(1.0774)
#tensor(1.0774)
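To see what these wrappers actually compute, here is a manual check (a sketch, not library code): take each sample's log-probability at its target class and average the negatives.

logp = torch.log_softmax(input, dim=1)           # fused, numerically stable log-softmax
manual = -logp[torch.arange(3), target].mean()   # -log p(target class), averaged
print(manual)                                    # matches both losses above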
Here we only discuss the multi-class case.
This function corresponds to PyTorch's cross_entropy.
In TensorFlow, "logits" in a function name means the function normalizes the raw logit values with softmax or sigmoid internally. It also signals that you should not apply sigmoid or softmax to the network output yourself; those steps can be computed more efficiently (and more stably) inside the function.
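A quick pitfall sketch (my own illustration): if you softmax the outputs yourself and then hand them to a loss that expects logits, they get normalized twice and the loss is silently wrong.

logits = torch.randn(3, 3)
probs = torch.softmax(logits, dim=1)
t = torch.tensor([0, 2, 1])
print(torch.nn.CrossEntropyLoss()(logits, t))  # correct: raw logits in
print(torch.nn.CrossEntropyLoss()(probs, t))   # wrong: probabilities re-normalized as if they were logits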
In PyTorch:
outputs = torch.tensor([[[1.4, 1.2], [0.0, 0.0], [0.3, 1.8], [0.3, 0.0]],
                        [[1.4, 1.2], [0.0, 0.0], [0.3, 1.8], [0.3, 0.0]]]).to(torch.float32)
labels = torch.tensor([[0, 0], [0, 0]])
loss = torch.nn.CrossEntropyLoss()(outputs, labels)   # for (N, C, d1) inputs, dim 1 is the class axis
loss, outputs.shape, labels.shape
# Output
#(tensor(0.9396), torch.Size([2, 4, 2]), torch.Size([2, 2]))
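With extra dimensions, PyTorch always treats dim 1 as the class axis. A quick sketch to confirm: flattening (N, C, d1) into (N*d1, C) rows gives the same loss.

flat_out = outputs.permute(0, 2, 1).reshape(-1, 4)       # (2, 4, 2) -> (4, 4)
flat_lab = labels.reshape(-1)                            # (2, 2) -> (4,)
print(torch.nn.CrossEntropyLoss()(flat_out, flat_lab))   # tensor(0.9396)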
In TensorFlow/Keras, tf.nn.sparse_softmax_cross_entropy_with_logits is equivalent to Keras's sparse_categorical_crossentropy(y_true, y_pred, from_logits=True).
import tensorflow as tf
from tensorflow.keras import backend as K

y_pred = outputs.permute(0, 2, 1).numpy()   # TF wants the class axis last: (2, 4, 2) -> (2, 2, 4)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels.numpy(), logits=y_pred)
K.mean(loss), loss
# Output (the mean matches the PyTorch result above)
# (0.9396, [[0.6483, 1.2310], [0.6483, 1.2310]])
loss = K.sparse_categorical_crossentropy(
    labels.numpy(), y_pred, from_logits=True)
K.mean(loss), loss
# Output (identical to the above)
# (0.9396, [[0.6483, 1.2310], [0.6483, 1.2310]])
Also note that the Keras version defaults to from_logits=False.
Here from_logits=False means a softmax has already been applied outside: over axis=1 for 2-D predictions, over axis=2 for 3-D predictions (i.e. the last, class axis). The log is still taken inside sparse_categorical_crossentropy.
y_pred = (np.arange(20) / 15).reshape(1, 10, 2)
y_pred = np.transpose(y_pred, (0, 2, 1))    # (1, 10, 2) -> (1, 2, 10): classes last
y_pred2 = K.softmax(y_pred, 2)              # probabilities, softmax over the class axis
y_true = np.array([[2, 9]])
K.mean(K.sparse_categorical_crossentropy(y_true, y_pred2)), K.mean(K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True))
# Output: the two means are identical
# (2.2415, 2.2415)
The key point: "sparse" means the targets are integer class indices, while the non-sparse version expects one-hot encoded targets.
def categorical_crossentropy(y_true, y_pred):
    # make y_true's shape and dtype explicit, then one-hot encode it
    y_pred = K.cast(y_pred, 'float32')
    y_true = K.reshape(y_true, K.shape(y_pred)[:-1])
    y_true = K.cast(y_true, 'int32')
    y_true = K.one_hot(y_true, K.shape(y_pred)[2])
    return K.mean(K.categorical_crossentropy(y_true, y_pred))
categorical_crossentropy(y_true, y_pred2)
# Output
# 2.2415
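As a cross-check (a sketch using tf.one_hot directly, with the dtype matched to avoid a float32/float64 mismatch): integer targets with the sparse loss and one-hot targets with the categorical loss give the same mean.

y_true_oh = tf.one_hot(y_true, 10, dtype=tf.float64)                # (1, 2) indices -> (1, 2, 10) one-hot
print(K.mean(K.categorical_crossentropy(y_true_oh, y_pred2)))       # ≈ 2.2415
print(K.mean(K.sparse_categorical_crossentropy(y_true, y_pred2)))   # ≈ 2.2415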