nn.DataParallel权重保存和读取,单卡单机权重保存和读取,二者之间的转换。

在自己电脑上(单卡)调试好模型,然后放到服务器(多卡)上跑,设置成了多卡训练,保存的模型字典中自动都增加了一个module,导致我在自己电脑上加载时候checkpoints不匹配。所以有了这份记录。

出处:pytorch 使用DataParallel 单机多卡和单卡保存和加载模型的正确方法 - 知乎 (zhihu.com)

1.单卡训练,单卡加载

这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。

保存:

states = {
        'state_dict_encoder': encoder.state_dict(),
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构

encoder = Encoder()

decoder = Decoder()

#然后加载参数

checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置

encoder_state_dict=checkpoint['state_dict_encoder']

decoder_state_dict=checkpoint['state_dict_decoder']

encoder.load_state_dict(encoder_state_dict)

decoder.load_state_dict(decoder_state_dict)

2.单卡训练,多卡加载

保存:

保存过程一样,不做任何改变

states = {
        'state_dict_encoder': encoder.state_dict(),
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:

加载过程也没有任何改变,但是要注意先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

3.多卡训练,单卡加载

注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()

保存:

states = {
        'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()
        'state_dict_decoder': decoder.module.state_dict(),
    }
torch.save(states, fname)

加载:

要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
同时,你也可以用第二种方式去保存和加载:

3.多卡训练,单卡加载,方法二

使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)

保存:

states = {
        'state_dict_encoder': encoder.state_dict(), 
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载:

要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

4.多卡保存,多卡加载

这就和多卡保存,单卡加载第二中方式一样了 使用model.state_dict()保存,加载的时候,要先把模型做并行化(在多卡上并行)

保存:

states = {
        'state_dict_encoder': encoder.state_dict(), 
        'state_dict_decoder': decoder.state_dict(),
    }
torch.save(states, fname)

加载: 要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

你可能感兴趣的:(pytorch,python,深度学习,机器学习)