Pytorch1.0 与 0.4版的兼容问题解决

问题描述

实验室服务器上安装的是pytorch1.0 GPU版本,但是自己的本地的电脑上是早期安装的pytorch0.4.1CPU版。在服务器上跑完了跑完了一个网络,想要直接移植到本地电脑上运行就出现了问题。

GPU版的pytorch训练的模型怎么在CPU的框架上运行

因为GPU的版pytorch训练的模型里有cuda相关的属性,而CPU版的pytorch里没有,直接将GPU模型放到CPU框架上执行下面语句会报错,正确的方式是先在GPU版的pytorch上将原来的模型转换为CPU模型

model_state = torch.load('\ModulePath\net.pkl')
net.load_state_dict(model_state)

转换模型类型操作

net = Net(*parameters)
#这一行指定load 到cpu很关键
model_state = torch.load('\ModulePath\net.pkl',map_location=torch.device('cpu'))
net.load_state_dict(model_state)
#再重新保存模型就可以在CPU版的pytorch上运行了
torch.save(net.state_dict(),'\ModulePath\net.pkl')

pytorch 0.4 与 pytorch 1.0兼容问题

在 1.0版本的pytorch中,module.parameters()里比0.4 多了一类“batches_tracked”参数,用来追踪每个可学习的参数训练了多少次,但是在0.4里是没有的,其他的是一样的,模型的参数都保存在model_state中,它是一个有序字典型对象(OrderedDict)。

Pytorch1.0 与 0.4版的兼容问题解决_第1张图片
如果能够把1.0版的model_state中的含有“num_batches_tracked”的键以及对应的值都删除掉那么久可以顺利使用了,下面给出model_state的修改代码

#定义需要删除的键
delkeys = []
#寻找需要删除的键的完整名称
for k in model_state:
	if 'batches_tracked' in k
		delkeys.append(k)
#删除找到的键及对应的值
for delk in delkeys:
	model_state.__delitem__(delk)

这样就能把pytorch 1.0上的网络修改到0.4上可以用了。

你可能感兴趣的:(经验帖)