1、保存与调用
class modelfunc(nn.Module):
    ...  # 之前定义好的模型(此处省略具体实现)
# 由于 PyTorch 没有像 Keras 那样单独保存模型结构的 API(保存整个模型时,pickle 只记录模型类的导入路径,而不保存类本身),因此每次 load 之前必须能找到(导入)模型类的定义。
model_object = modelfunc()  # 实例化模型(导入模型结构)
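# 保存和加载整个模型(注意:PyTorch 2.6 及以上版本中 torch.load 默认 weights_only=True,加载整个模型时需显式传入 weights_only=False,且只应加载可信来源的文件)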
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
# 仅保存和加载模型参数
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))
2、torch.load 的输出:
# 保存和加载整个模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')
print(model)
>>>【结果】
modelfunc(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
)
)
)
3、torch.load('params.pth') 加载出的模型参数(state_dict,即传给 model_object.load_state_dict 的字典)的输出:
# 仅保存和加载模型参数
torch.save(model_object.state_dict(), 'params.pth')
dic = torch.load('params.pth')
model_object.load_state_dict(dic)
print(dic)
>>>【结果】
OrderedDict([('conv1.weight', tensor([[[[ 1.3335e-02, 1.4664e-02, -1.5351e-02, ..., -4.0896e-02,
-4.3034e-02, -7.0755e-02],
[ 4.1205e-03, 5.8477e-03, 1.4948e-02, ..., 2.2060e-03,
-2.0912e-02, -3.8517e-02],
[ 2.2331e-02, 2.3595e-02, 1.6120e-02, ..., 1.0281e-01,
6.2641e-02, 5.1977e-02],
...,
('bn1.weight', tensor([ 2.3888e-01, 2.9136e-01, 3.1615e-01, 2.7122e-01, 2.1731e-01,
3.0903e-01, 2.2937e-01, 2.3086e-01, 2.1129e-01, 2.8054e-01,
1.9923e-01, 3.1894e-01, 1.7955e-01, 1.1246e-08, 1.9704e-01,
2.0996e-01, 2.4317e-01, 2.1697e-01, 1.9415e-01, 3.1569e-01,
1.9648e-01, 2.3214e-01, 2.1962e-01, 2.1633e-01, 2.4357e-01,
2.9683e-01, 2.3852e-01, 2.1162e-01, 1.4492e-01, 2.9388e-01,
2.2911e-01, 9.2716e-02, 4.3334e-01, 2.0782e-01, 2.7990e-01,
3.5804e-01, 2.9315e-01, 2.5306e-01, 2.4210e-01, 2.1755e-01,
3.8645e-01, 2.1003e-01, 3.6805e-01, 3.3724e-01, 5.0826e-01,
1.9341e-01, 2.3914e-01, 2.6652e-01, 3.9020e-01, 1.9840e-01,
2.1694e-01, 2.6666e-01, 4.9806e-01, 2.3553e-01, 2.1349e-01,
2.5951e-01, 2.3547e-01, 1.7579e-01, 4.5354e-01, 1.7102e-01,
2.4903e-01, 2.5148e-01, 3.8020e-01, 1.9665e-01])),
('bn1.bias', tensor([ 2.2484e-01, 6.0617e-01, 1.2483e-02, 1.3270e-01, 1.8030e-01,
1.4739e-01, 1.7430e-01, 1.9023e-01, 2.3226e-01, 2.0082e-01,
1.2834e-01, -2.1285e-01, 1.5065e-01, -3.9217e-08, 2.4985e-01,
2.0454e-01, 5.4934e-01, 2.1021e-01, 2.2505e-01, 4.6484e-01,
2.3888e-01, 2.0442e-01, 2.1546e-01, 6.6194e-01, 2.2755e-01,
6.6069e-01, 2.0587e-01, 1.9292e-01, 1.1195e-01, 3.3785e-01,
1.2393e-01, 4.1079e-02, 7.7150e-01, 2.6964e-01, 3.3347e-01,
5.7908e-01, 1.5026e-01, 1.7534e-01, 1.9429e-01, 1.7248e-01,
8.0577e-01, 2.3693e-01, -4.3369e-01, 8.4813e-01, -3.7857e-01,
2.4787e-01, 1.8101e-01, 3.2949e-01, -2.8598e-01, 2.2717e-01,
2.6168e-01, 5.7609e-02, -5.0320e-01, 1.5704e-01, 1.7890e-01,
2.8114e-01, 4.2167e-01, -9.7650e-02, -3.1231e-01, -2.5637e-02,
8.8566e-02, 1.8052e-01, 8.3045e-01, 2.5015e-01])),
('bn1.running_mean', tensor([ 2.8781e-02, 1.0830e-01, 2.6812e-01, -4.7955e-02, -2.7350e-02,
-1.2350e-02, -2.8534e-02, 3.8390e-02, 8.6643e-03, 1.1076e-01,
-1.6231e-02, -7.1499e-01, 5.7644e-02, -5.1895e-07, -1.9860e-02,
6.5988e-03, 4.9869e-01, -3.4726e-02, -2.2373e-02, -6.4198e-01,
3.3326e-02, 6.5970e-02, 3.1869e-02, 3.1863e-01, 3.7692e-02,
4.9075e-01, 3.0402e-02, -6.5330e-02, -2.4589e-02, 4.3018e-01,
-6.3207e-02, 3.6987e-02, -7.9438e-01, 3.7037e-02, 8.1242e-01,
-8.8931e-01, -3.4412e-02, -1.6578e-01, -1.8018e-02, -2.7667e-02,
-1.3835e+00, 7.8008e-02, -7.0342e-01, 3.4551e-01, 5.7252e-01,
4.5663e-02, 5.2766e-02, 2.8974e-01, -3.4401e-01, 1.6897e-02,
9.7269e-02, -2.1634e-02, 7.9793e-01, 1.7612e-02, -3.2805e-03,
-1.7782e-01, -1.4005e-01, 4.1215e-02, 7.2888e-01, -2.2417e-01,
1.9287e-03, 8.7772e-02, 1.3144e+00, -3.8825e-02])),
('bn1.running_var', tensor([ 5.0796e-01, 1.4441e+00, 3.3001e+00, 3.3098e+00, 1.3029e-01,
3.3023e+00, 1.2143e-01, 2.5986e-01, 8.9925e-02, 2.9480e+00,
1.3752e-01, 2.1341e+00, 6.9679e-02, 2.7234e-12, 2.4457e-02,
6.9063e-02, 1.1395e+00, 8.0611e-02, 2.1984e-02, 2.6701e+00,
5.6415e-02, 2.1792e-01, 1.0816e-01, 9.8851e-01, 3.0843e-01,
2.9959e+00, 5.4037e-02, 1.7887e-01, 2.8518e-02, 1.8343e+00,
7.0009e-01, 2.9475e-02, 1.1048e+01, 7.5987e-03, 2.6686e+00,
5.0308e+00, 2.8717e+00, 1.7434e+00, 3.8133e-01, 1.3055e-01,
8.6697e+00, 3.9596e-02, 2.3990e+00, 3.7014e+00, 6.9698e+00,
1.2682e-01, 1.4923e-01, 1.5581e+00, 1.1554e+00, 2.0051e-02,
1.3014e-01, 9.9781e-01, 3.6349e+00, 2.4568e-01, 1.2094e-01,
7.6329e-01, 7.9295e-01, 1.5916e-01, 3.8380e+00, 3.2014e-01,
3.4269e-01, 3.3512e-01, 8.0546e+00, 2.4255e-02])), ('layer1.0.conv1.weight', tensor([[[[ 3.5144e-03]],
[[ 3.9855e-02]],
[[-2.4795e-02]],
...,
('layer1.0.bn1.weight', tensor([ 2.1341e-01, 1.8848e-01, 1.4136e-01, 1.5273e-01, 1.3220e-01,
1.8735e-01, 1.4475e-01, 4.5110e-08, 1.5993e-01, 1.4946e-01,
2.3499e-01, 1.8315e-01, 1.8516e-01, 1.4933e-01, 1.3090e-01,
1.0634e-01, 3.7487e-01, 1.2644e-01, 3.1895e-01, 2.7160e-01,
2.5810e-01, 2.9458e-01, 1.8395e-01, 2.1088e-08, 3.3313e-01,
2.0461e-01, 3.0399e-01, 1.1805e-08, 1.4977e-01, 1.5719e-01,
1.4011e-01, 1.4900e-01, 1.2438e-01, 1.8786e-01, 1.4257e-01,
3.4828e-01, 1.5038e-01, 3.0034e-01, 2.5925e-01, 1.0711e-01,
2.6875e-01, 1.3552e-01, 1.1822e-01, 1.1189e-01, 2.8736e-01,
3.2637e-01, 1.4781e-01, 2.3105e-01, 3.3638e-01, 2.8808e-01,
1.2319e-01, 3.0763e-01, 1.1846e-01, 1.3137e-01, 2.0671e-01,
1.5787e-01, 2.6574e-08, 2.0467e-01, 2.8797e-08, 1.8284e-01,
3.0180e-01, 1.7401e-01, 2.8438e-01, 2.3715e-01])),
('layer1.0.bn1.bias', tensor([ 4.3266e-01, 4.6854e-02, -8.0134e-02, 7.3302e-02, 2.7970e-01,
-7.8047e-03, 9.4087e-02, -1.0086e-07, -1.4034e-01, -5.1599e-02,
4.4470e-02, 2.1814e-01, 4.0718e-02, 1.1979e-01, 1.4432e-01,
1.3672e-01, -1.1168e-01, 1.4774e-01, -1.2879e-01, -5.3147e-02,
-3.3920e-02, -2.0600e-02, 6.2783e-02, -6.5736e-08, -7.1213e-02,
6.9510e-02, -1.3264e-01, -6.4411e-08, -2.8908e-02, 9.4164e-02,
2.4790e-01, -8.2850e-02, -2.8872e-02, -1.7086e-01, 9.9522e-02,
-1.1357e-01, 1.9770e-01, 1.4800e-02, -7.0896e-02, 1.0722e-01,
1.2536e-02, -3.6633e-02, 1.4959e-01, 1.0533e-01, 2.0933e-02,
-1.0502e-01, -4.8848e-02, 4.9007e-01, -1.4755e-01, -1.0900e-01,
1.9815e-02, -7.0964e-02, -4.6543e-02, 1.0874e-01, -2.7878e-01,
4.4500e-03, -7.7156e-08, 7.5060e-02, -8.4474e-08, 2.2533e-01,
-7.1593e-02, -1.5823e-01, -3.4459e-02, 5.2894e-01])),
('layer1.0.bn1.running_mean', tensor([-6.0619e-01, -3.5467e-01, 2.4651e-01, -2.5210e-01, -7.6892e-02,
-3.3654e-01, -1.0111e-01, -1.7881e-08, 2.1631e-01, -2.8016e-01,
-3.1948e-01, 1.1134e+00, -1.1791e-01, -2.0125e-01, -3.2957e-01,
-2.6431e-02, -3.4833e-01, 7.1402e-01, -2.7727e-01, -2.7576e-01,
-1.7791e-01, -1.1054e-01, -1.5952e-01, -5.6052e-45, -3.6867e-01,
-1.7413e-01, -2.6344e-01, 4.3125e-09, -2.3616e-01, -3.0546e-01,
-1.8908e-02, 2.2109e-01, 1.1146e-02, -1.4291e-01, -3.0156e-01,
-4.4344e-01, -2.2829e-01, -2.0861e-01, -2.2197e-01, 3.1603e-01,
-1.1507e-01, -1.3784e-01, -2.9271e-01, -4.8246e-01, -1.5741e-01,
-2.6682e-01, -3.8136e-01, -3.1360e-01, -1.9755e-01, -4.1116e-01,
-2.8717e-02, -3.0186e-01, 8.8766e-02, -3.3887e-01, -5.9848e-02,
-6.4817e-01, 1.2924e-09, -2.2738e-01, -5.6052e-45, 1.0252e+00,
-9.3871e-02, -1.4969e-02, -4.0218e-01, -1.3630e-01])),