When training neural networks, we sometimes need to share parameters between different models. Below is a worked example of how to transfer parameters from one model to another.
⭐ The modules that share parameters must have exactly the same structure.
Suppose we have the following two models:
import torch
import torch.nn as nn

class ANN1(nn.Module):
    def __init__(self, features):
        super(ANN1, self).__init__()
        self.features = features
        self.nn_same = torch.nn.Sequential(   # block whose parameters will be shared
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff = torch.nn.Sequential(   # block specific to this model
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch_size, features)
        x = self.nn_same(x)
        x = self.nn_diff(x)
        return x
class ANN2(nn.Module):
    def __init__(self, features):
        super(ANN2, self).__init__()
        self.features = features
        self.nn_same = torch.nn.Sequential(   # same structure as ANN1.nn_same
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff = torch.nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch_size, features)
        x = self.nn_same(x)
        x = self.nn_diff(x)
        return x
model1 = ANN1(10)
model2 = ANN2(10)
print(model1)
print(model2)
ANN1(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
ANN2(
  (nn_same): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
Here nn_same is the module whose parameters we want to share. The module names may differ between the two models, but the module structures must be identical.
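A minimal sketch of this point (ANN3 and its attribute names encoder/head are made up for illustration): a submodule's state_dict uses keys relative to that submodule ("0.weight" and "0.bias" here), so its parameters can be loaded into a module with a different attribute name, as long as the layer structure matches.
class ANN3(nn.Module):
    def __init__(self, features):
        super(ANN3, self).__init__()
        self.encoder = torch.nn.Sequential(   # different name, same structure as nn_same
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.head = torch.nn.Sequential(nn.Linear(128, 1))

    def forward(self, x):
        return self.head(self.encoder(x))

model3 = ANN3(10)
# Works because both submodules share the relative keys "0.weight" and "0.bias".
model3.encoder.load_state_dict(model1.nn_same.state_dict())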
Because parameters are randomly initialized when a model is constructed, the two models' parameters will certainly differ. Suppose we want to transfer the parameters of nn_same in model1 into nn_same in model2. First, look at the parameters of model1.nn_same:
for param_tensor in model1.nn_same.state_dict():  # print the parameters before the transfer
    print(param_tensor, "\t", model1.nn_same.state_dict()[param_tensor])
0.weight tensor([[ 0.1321, -0.0178, 0.1631, ..., -0.2531, -0.1584, 0.0588],
[-0.2466, -0.0381, 0.2394, ..., -0.2924, -0.1267, -0.1791],
[-0.1713, -0.0716, 0.0598, ..., 0.1655, -0.1947, 0.0927],
...,
[-0.1795, -0.3082, -0.2846, ..., 0.2588, -0.0998, -0.1285],
[-0.2739, -0.1587, 0.1803, ..., -0.1905, -0.2832, -0.0724],
[ 0.1375, -0.1854, -0.1928, ..., 0.1470, 0.2928, 0.1385]])
0.bias tensor([-0.2251, -0.3036, 0.2147, -0.0798, -0.1079, -0.0396, -0.1078, 0.1006,
-0.1884, -0.0616, 0.0698, 0.0044, 0.1615, -0.2090, 0.0584, -0.0743,
        ...,
0.3010, -0.1674, 0.0982, 0.2267, -0.0865, -0.1350, -0.2501, 0.1475,
0.0187, 0.0819, 0.1840, -0.0988, 0.0133, -0.2082, 0.0376, 0.2993])
Now let's perform the parameter transfer:
print("****************迁移前*****************")
for param_tensor in model2.nn_same.state_dict():#输出迁移前的参数
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
model_nn_same = model1.nn_same.state_dict() ##获取model的nn_same部分的参数
model2.nn_same.load_state_dict(model_nn_same,strict=True) #更新model2 nn_same部分的参数,#更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
print("****************迁移后*****************")
for param_tensor in model2.nn_same.state_dict():#输出迁移后的参数
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
#此时nn_same参数更新,nn_diff2参数不变
**************** Before transfer *****************
0.weight tensor([[-0.1030, -0.0111, 0.0989, ..., -0.3142, -0.0167, 0.0485],
[ 0.1671, 0.2833, 0.1353, ..., 0.1657, -0.2497, -0.1680],
[ 0.0470, 0.1208, 0.1707, ..., -0.0018, 0.2497, 0.0419],
...,
[-0.2406, -0.2757, 0.2527, ..., -0.0888, -0.2772, 0.1019],
[-0.3035, -0.0227, -0.0194, ..., 0.1280, -0.1167, 0.1060],
[ 0.0565, 0.1870, -0.2729, ..., -0.1215, 0.1343, -0.1057]])
0.bias tensor([ 0.0855, 0.3137, 0.2336, -0.2197, 0.0132, -0.1812, -0.1490, -0.1348,
0.1027, 0.0284, 0.1064, 0.2046, 0.1106, -0.2034, -0.1283, -0.1561,
        ...,
0.0328, -0.1035, -0.2942, -0.2368, -0.2290, 0.1846, -0.0270, 0.1286,
-0.2331, 0.1111, 0.2172, -0.2865, 0.2086, -0.1388, -0.2077, -0.2976])
**************** After transfer *****************
0.weight tensor([[ 0.1321, -0.0178, 0.1631, ..., -0.2531, -0.1584, 0.0588],
[-0.2466, -0.0381, 0.2394, ..., -0.2924, -0.1267, -0.1791],
[-0.1713, -0.0716, 0.0598, ..., 0.1655, -0.1947, 0.0927],
...,
[-0.1795, -0.3082, -0.2846, ..., 0.2588, -0.0998, -0.1285],
[-0.2739, -0.1587, 0.1803, ..., -0.1905, -0.2832, -0.0724],
[ 0.1375, -0.1854, -0.1928, ..., 0.1470, 0.2928, 0.1385]])
0.bias tensor([-0.2251, -0.3036, 0.2147, -0.0798, -0.1079, -0.0396, -0.1078, 0.1006,
-0.1884, -0.0616, 0.0698, 0.0044, 0.1615, -0.2090, 0.0584, -0.0743,
        ...,
0.3010, -0.1674, 0.0982, 0.2267, -0.0865, -0.1350, -0.2501, 0.1475,
0.0187, 0.0819, 0.1840, -0.0988, 0.0133, -0.2082, 0.0376, 0.2993])
As you can see, the parameters of model2's nn_same module now match those of model1's nn_same module.
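Rather than eyeballing the printed tensors, a quick programmatic check with torch.equal confirms the transfer:
print(torch.equal(model1.nn_same[0].weight, model2.nn_same[0].weight))  # True
print(torch.equal(model1.nn_same[0].bias, model2.nn_same[0].bias))      # True
# nn_diff was not touched, so with random init its weights will (almost surely) still differ:
print(torch.equal(model1.nn_diff[0].weight, model2.nn_diff[0].weight))  # False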
Now suppose instead that we have the following two models, each with two shared blocks:
class ANN1(nn.Module):
    def __init__(self, features):
        super(ANN1, self).__init__()
        self.features = features
        self.nn_same1 = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_same2 = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff1 = torch.nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch_size, features)
        # both shared blocks read the input; sum their outputs (one simple choice) before the head
        x = self.nn_same1(x) + self.nn_same2(x)
        x = self.nn_diff1(x)
        return x
class ANN2(nn.Module):
    def __init__(self, features):
        super(ANN2, self).__init__()
        self.features = features
        self.nn_same1 = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_same2 = torch.nn.Sequential(
            nn.Linear(features, 128),
            torch.nn.ReLU(),
        )
        self.nn_diff2 = torch.nn.Sequential(
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # x: (batch_size, features)
        x = self.nn_same1(x) + self.nn_same2(x)
        x = self.nn_diff2(x)
        return x
model1 = ANN1(10)
model2 = ANN2(10)
print(model1)
print(model2)
ANN1(
  (nn_same1): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_same2): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff1): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
ANN2(
  (nn_same1): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_same2): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
  )
  (nn_diff2): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
  )
)
Suppose we want to transfer the parameters of nn_same1 and nn_same2 in model1 into nn_same1 and nn_same2 in model2:
print("****************迁移前*****************")
for param_tensor in model2.state_dict():#输出迁移前的参数
print(param_tensor, "\t", model2.state_dict()[param_tensor])
model_all = model1.state_dict() ##获取model的所有的参数
model2.load_state_dict(model_all,strict=False) #更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
print("****************迁移后*****************")
for param_tensor in model2.state_dict():#输出迁移后的参数
print(param_tensor, "\t", model2.state_dict()[param_tensor])
#此时nn_same参数更新,nn_diff2参数不变
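With strict=False, load_state_dict does not raise on mismatched names; instead it returns a named tuple listing the keys it skipped. Printing it is a cheap sanity check that only the intended modules were left out:
result = model2.load_state_dict(model_all, strict=False)
print(result.missing_keys)     # keys in model2 with no counterpart: ['nn_diff2.0.weight', 'nn_diff2.0.bias']
print(result.unexpected_keys)  # keys in model1 that model2 lacks: ['nn_diff1.0.weight', 'nn_diff1.0.bias']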
Note the strict=False in model2.load_state_dict(model_all, strict=False): it means the two models' module names do not need to match exactly, and only parameters whose keys match are copied. If the module names do not fully match but strict=True is used (the default), a RuntimeError is raised, as the traceback below shows: model2's nn_diff2 keys are reported as missing, and model1's nn_diff1 keys as unexpected:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-56-069ae53e28f3> in <module>
4
      5 model_all = model1.state_dict()  # get all of model1's parameters
----> 6 model2.load_state_dict(model_all, strict=True)  # strict=True requires every key to match (default)
      7
      8 print("**************** After transfer *****************")
D:\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
1481 if len(error_msgs) > 0:
1482 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1483 self.__class__.__name__, "\n\t".join(error_msgs)))
1484 return _IncompatibleKeys(missing_keys, unexpected_keys)
1485
RuntimeError: Error(s) in loading state_dict for ANN2:
Missing key(s) in state_dict: "nn_diff2.0.weight", "nn_diff2.0.bias".
Unexpected key(s) in state_dict: "nn_diff1.0.weight", "nn_diff1.0.bias".