pytorch中保存网络损失loss简单代码

整体思路为,将每个iter 和每次输出的loss进行保存,然后利用plt进行绘图。

1## 保存loss和iteration

loss = []
iteration=[]
iteration.append(i)		#i是你的iter
loss.append(total_loss.item())		#total_loss.item()是你每一次inter输出的loss
    

2## plt绘图

	#num_iter是你总的迭代次数
    if i==num_iter-1:
        plt.figure()
        plt.plot(iteration, loss, label="loss")
        plt.draw()
        plt.show()

(后序:脑子不好的人只能记录每一个细节,如果还能帮到和我一样的小白,就超级好。嘻嘻?)

你可能感兴趣的:(深度网络训练,loss,pytorch,plt)