Pytorch学习(十一)Pytorch中.item()的用法

我第一次接触.item()是在做图像分类任务中,计算loss的时候。total_loss = total_loss + loss.item()

1. .item()的用法

.item()用于在只包含一个元素的tensor中提取值,注意是只包含一个元素,否则的话使用.tolist()

x = torch.tensor([1])
print(x.item())
y = torch.tensor([2,3,4,5])
print(y.item())
# 输出结果如下
1
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-e884e6ada778> in <module>
      2 print(x.item())
      3 y = torch.tensor([2,3,4,5])
----> 4 print(y.item())

ValueError: only one element tensors can be converted to Python scalars

由结果看出,print(y.item())报错了,报错说只有一个元素的tensor才能使用.item()转换成标量,所以我们把y.item()改成.tolist()试试。

x = torch.tensor([1])
print(x.item())
y = torch.tensor([2,3,4,5])
print(y.tolist())
# 输出结果
1
[2, 3, 4, 5]

2. 为何要使用.item()

在训练时统计loss变化时,会用到loss.item(),能够防止tensor无线叠加导致的显存爆炸

3. 为何loss还需要乘batch_size呢

比如这个语句: train_loss += loss.item() * images.size(0) ,为什么总的损失train_loss不直接累加loss.item(),还在后边乘上images.size(0)呢?
loss.item()应该是一个batch size的平均损失,×images.size(0)那就是一个batch size的总损失,所以train_loss很可能是求一个epoch的loss之和。

你可能感兴趣的:(Pytorch系列学习,深度学习与神经网络,Python,pytorch,深度学习,python,人工智能)