关于torch.load的一点小tip

今天同事问我,他说好奇怪啊,为啥这个模型总是加载报错啊,总是报显存不足,喊我过去看看,我去瞄了一下,说你试试在torch.load里面加一个参数,指定cpu加载

torch.load('XXX.pth', map_location='cpu')

看了下内存变化,然后发现内存也不过是增加2G左右而已,而看了下显卡信息,明明是空的,那为啥会报显存不足的错呢。

此时又看了下,这台电脑是双显卡,另一张卡在跑程序,显存占用快满了,我就问了一句,你这张空卡加载别的模型可以么,回复可以的,再问,你能确定没有加载到这个满卡上,回复,怎么会呢,明明device写的是‘cuda:0’。

我就让他把torch.load里面的'cpu' 改为‘cuda:0’,发现一切正常了

这里面具体什么原理我也不太清楚,只是有一点,这个"map_location"的参数,如果你不指定的话,他是你这个模型在什么设备上保存,就会加载到什么设备上,所以他这个模型应该是在卡1上保存的,然后加载的时候默认加载到卡1上,然后报错了。

在这个问题上,差不多就是,你在gpu上训练的模型,在无gpu的设备上load,会报错是一样的道理,必须要指定"map_location"参数为'cpu',就这么简单,至于什么device,和这个没关系

网上下载过一些项目来借鉴、学习、使用的,看到个人觉得比较规范的写法就是先加载到cpu,再to到相应的device,这样可以避免,他使用的设备在我这边没有,而不可使用的情况,个人觉得还是比较有意义的。

 

附带一个小知识,以前在用tensorflow的时候,一般会一起装一个tensorboard,那时候也不知道是啥,也不会用,有没有曲线输出,完全取决于git上别人的代码有没有写,后来用pytorch,也就没有这个玩意了,更没操心了,后来下载代码的时候,发现有个库叫tensorboardX,这个就可以提供绘制曲线的功能了,再后来pytorch自己集成了,使用上,似乎是和tensorboardX是一样的

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./log')

writer.add_scalar('Loss/train', loss.item(), global_step)

writer.add_scalar('Acc/acc1', acc1[0], global_step)
writer.add_scalar('Acc/acc5', acc5[0], global_step)

先引用这个包,然后创建一个写入实例,这里就提供一下生成曲线的api,他就是每次画一个点,横坐标就是global_step,纵坐标就是loss了,然后多个点呈现的时候会自动连城线,前面的路径一样的东西,这么写的话,他会有一个栏目叫Loss,下面有一张图叫train,需要注意的是,这个global_step应该是全局的哦,不要下一个循环就变为0了,这样你会看到这个曲线会很奇怪的

关于torch.load的一点小tip_第1张图片

 

终端里面敲“tensorboard --logdir=./log”,然后就会出现一个网址,点击进去就能看到了,这个./log就是前面创建写入实例的时候的路径哈,其实这个路径里面是可以有多个文件夹的,这样你指定母文件夹的话,他会把每个文件夹当成一个项目来呈现,用不同颜色来表示,感兴趣可以自己试试

知识很小,但是谁还不是AI小学生呢,大家一起努力

你可能感兴趣的:(小知识,深度学习,pytorch)