在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm’s running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453
torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。
因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。
作用:用来加载torch.save()
保存的模型文件。
使用方式
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)
参数解释:
f:权重文件地址
map_location:设备 CPU还是GPU
后两个参数可以不用管
d1.update(d2)的作用是,将字典d2的内容合并到d1中,
其中d2中的键值对但d1中没有的键值对会增加到d1中去,
两者都有的键值对更新为d2的键值对.
d1 = {"浙江":"杭州","江苏":"nanjing"}
d1
{'浙江': '杭州', '江苏': 'nanjing'}
d1.update(江苏="南京")
d1
{'浙江': '杭州', '江苏': '南京'}
d2 = {"山东":"济南","河北":"石家庄"}
d1
{'浙江': '杭州', '江苏': '南京'}
d1.update(d2)
d1
{'浙江': '杭州', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
d3 = {"浙江":"杭州市*****"}
d1
{'浙江': '杭州', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
d1.update(d3)
d1
{'浙江': '杭州市*****', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
Python 字典(Dictionary) items() 函数以列表返回可遍历的(键, 值) 元组数组。
items() 方法把字典中每对 key 和 value 组成一个元组,并把这些元组放在列表中返回
与state_dict相比,我理解的是,load_state_dict是更新好的权重放回去,state_dict是将权重系数取出来。
权重的预加载可以综合到这几步
if G_model_path != '':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = G_model.state_dict()
pretrained_dict = torch.load(G_model_path, map_location=device)
pretrained_dict = {k : v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
model_dict.update(pretrained_dict)
G_model.load_state_dict(model_dict)
第一步:获取当前设备到底是GPU还是CPU
第二步:取出当前还未加载权重的字典
第三步:用torch.load获取与训练好的新的权重的字典
第四步:在第三步的字典中,判断第三步新的权重和原始模型的权重的大小shape是否一致
如果一致,新的权重字典就保留这个(键、值)权重
如果不一致,新的权重字典就舍去这个(键、值)权重
第五步:用第四步最新的 字典 来更新第二步的原始字典
第六步:用第五步更新后的权重字典放回模型中。