pytorch学习:loss为什么要加item()

作者:陈诚
链接:https://www.zhihu.com/question/67209417/answer/344752405
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

PyTorch 0.4.0版本去掉了Variable,将Variable和Tensor融合起来,可以视Variable为requires_grad=True的Tensor。其动态原理还是不变。在获取数据的时候也变得更优雅:使用loss += loss.detach()来获取不需要梯度回传的部分。或者使用loss.item()直接获得所对应的python数据类型。============================================================

以下为原回答:算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。应该不止我经历过这个吧…久久不用又会不小心掉到这个坑里去…for data, label in trainloader:

    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里

运行着就发现显存炸了观察了一下发现随着每个batch显存消耗在不断增大…参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大

那么消耗的显存也就越来越大

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算… 包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~题外话想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。

大概是这样

z = x + y
z.grad_fn
out:

你可能感兴趣的:(技术栈)