torch 的 F.cross_entropy

torch中的交叉熵损失函数使用案例

import torch
import torch.nn.functional as F

pred = torch.randn(3, 5)
print(pred.shape)

target = torch.tensor([2, 3, 4]).long() # 需要是整数
print(target.shape)

# 交叉熵损失函数, 输入的参数是形状不一样的
# predict会在其内部进行softmax操作
loss = F.cross_entropy(pred, target)
loss.item()

结果为:

torch 的 F.cross_entropy_第1张图片

需要注意的是, 传入的参数形状是不同的, predict是softmax之前的, 另外y需要是整形的, int也行

你可能感兴趣的:(torch,torch,cross_entropy)