PyTorch 踩坑实录 (1) - 损失函数

今天开始更新学习 FaceBook 的深度学习框架 PyTorch !

PyTorch 底层优化的非常好,而且与 Numpy 无缝对接,用起来很清爽,不像 TensorFlow 那么“反 Python”~

先看了 Deep Learning with PyTorch: A 60 Minute Blitz ,题目说是“一小时搞定”,但就我这个上了岁数的人来讲,花了一晚上才把一整套流程跑了一遍。。。

  • 交叉熵损失函数

接下来看看踩得第一个坑,在使用交叉熵损失函数 (Cross Entropy) 时抛出异常:

RuntimeError: multi-target not supported at …\aten\src\THNN/generic/ClassNLLCriterion.c:20

程序运行的经过为:

cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(predicts, labels)

导致的原因为,函数的参数

predictsshapen*mnbatch_data 的样本数,m 为 模型输出层的维数;

labelsshape1*n,是一个一维 Tensor(在四分类任务中,形如 [1, 0, 3, 1, 2]),其中每个元素为类别的类别号(例如 一个样本的类别为飞机,其对应的类别标签为 1,则对应的 label 则为 1

  • 平方差损失函数

看一下另一个损失函数的坑,MSE(Mean Square Error):

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'target’

程序运行的经过为:

mse = nn.MSELoss()
loss = mse(result, batch_label)

导致原因为:

mse 方法要求 target 数据类型为 float,所以要对 batch_label 做数据类型转换:

# 可以直接在全部数据中转换
batch_label.float()

你可能感兴趣的:(深度学习,(DL),Python,PyTorch)