PyTorch Variable与动态图

https://mp.weixin.qq.com/s/OMjfck4jlMneGZ1NJxbjKQ

for data, label in trainloader:
    ......
    out = model(data)
    loss = criterion(out, label)
    loss_sum += loss     # <--- 这里
    ......
    # 正确的写法:loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。

而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

你可能感兴趣的:(Python)