F.cross_entropy 中 weight 参数的效果比较

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# Demonstrates how the per-class `weight` argument of F.cross_entropy affects
# the loss under the default 'mean' reduction.
#
# With a weight vector the mean is a *weighted* mean over the batch:
#     loss = sum(w[y_i] * l_i) / sum(w[y_i])
# where l_i is the unweighted loss of sample i and y_i its target class.
# Consequently the weights only matter when the batch mixes target classes
# that carry different weights.

# --- Batch of two samples with different targets (class 1 and class 2) ---
# Both rows share the same logits, so the per-sample losses are
# 1.4076 (target 1) and 0.4076 (target 2).
x = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
y = torch.tensor([1, 2])

# Uniform weights: plain average of the two per-sample losses.
w = torch.tensor([1.0, 1.0, 1.0])
res = F.cross_entropy(x, y, weight=w)
# 0.9076  == (1.4076 + 0.4076) / 2

# Class 1 weighted 10x: the result is pulled toward the class-1 sample's loss.
w = torch.tensor([1.0, 10.0, 1.0])
res = F.cross_entropy(x, y, weight=w)
# 1.3167  == (10 * 1.4076 + 1 * 0.4076) / (10 + 1)

# --- Single sample (target class 1) ---
x = torch.tensor([[1.0, 2.0, 3.0]])
y = torch.tensor([1])

w = torch.tensor([1.0, 1.0, 1.0])
res_single_uniform = F.cross_entropy(x, y, weight=w)
# 1.4076

# The class weight cancels in the weighted mean when there is only one
# sample: (10 * 1.4076) / 10 == 1.4076, so the loss is unchanged.
w = torch.tensor([1.0, 10.0, 1.0])
res_single_weighted = F.cross_entropy(x, y, weight=w)
# 1.4076

 

你可能感兴趣的:(PyTorch)