pytorch多GPU训练保存的模型,在单GPU环境下加载出错

背景
在公司用多卡训练模型,得到权值文件后保存,然后回到实验室,没有多卡的环境,用单卡训练,加载模型时出错,因为单卡机器上,没有使用DataParallel来加载模型,所以会出现加载错误。
原因
DataParallel包装的模型在保存时,权值参数前面会带有module字符,然而自己在单卡环境下,没有用DataParallel包装的模型权值参数不带module。本质上保存的权值文件是一个有序字典。
解决方法
1.在单卡环境下,用DataParallel包装模型。

2.自己重写Load函数,灵活。

from collections import OrderedDict
def myOwnLoad(model, check):
    modelState = model.state_dict()
    tempState = OrderedDict()
    for i in range(len(check.keys())-2):
        print modelState.keys()[i], check.keys()[i]
        tempState[modelState.keys()[i]] = check[check.keys()[i]]
    temp = [[0.02]*1024 for i in range(200)]  # mean=0, std=0.02
    tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda()
    tempState['myFc.bias']   = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda()

    model.load_state_dict(tempState)
    return model

你可能感兴趣的:(Pytorch学习,多GPU,模型加载)