pytorch专题 --- load模型

转载自 https://blog.csdn.net/hungryof/article/details/81364487
博客最后补充了另一种加载模型的方法。

一般来说,保存模型时用 torch.save(model.cpu().state_dict(), model_path) 把全部参数存下来,加载模型时用 model.load_state_dict(torch.load(model_path))。值得注意的是:这里 torch.load 返回的就是当初保存的 state_dict,它是一个 OrderedDict。

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """Same three-conv network as ``Net_old``, but every layer is its own attribute.

    Because each layer is registered directly on the module, the keys of
    ``state_dict()`` are ``conv1.weight``, ``conv1.bias``, ... ``conv3.bias``,
    which do not match the ``nets.<index>.*`` keys saved by ``Net_old``.
    """

    def __init__(self):
        # super() must name this class: Net_new is not a subclass of Net_old,
        # so super(Net_old, self) would raise TypeError.
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 to ``x`` and return the result."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')

# torch.load returns the object that was saved: here the OrderedDict
# produced by state_dict().
pretrained_net = torch.load('t.pth')
print(pretrained_net)

# Iterating an OrderedDict yields its keys, so enumerate() pairs each
# parameter name with its position -- this shows the saved order.
for idx, key in enumerate(pretrained_net):
    print(idx, key)

可以看到

OrderedDict([('nets.0.weight',
(0 ,0 ,.,.) =
 -0.2436  0.2523  0.3097
 -0.0315 -0.1307  0.0759
  0.0750  0.1894 -0.0761

(1 ,0 ,.,.) =
  0.0280 -0.2178  0.0914
  0.3227 -0.0121 -0.0016
 -0.0654 -0.0584 -0.1655
[torch.FloatTensor of size 2x1x3x3]
), ('nets.0.bias',
-0.0507
-0.2836
[torch.FloatTensor of size 2]
), ('nets.2.weight',
(0 ,0 ,.,.) =
 -0.2233  0.0279 -0.0511
 -0.0242 -0.1240 -0.0511
  0.2266  0.1385 -0.1070

(0 ,1 ,.,.) =
 -0.0943 -0.1403  0.0979
 -0.2163  0.1906 -0.2269
 -0.1984  0.0843 -0.0719
[torch.FloatTensor of size 1x2x3x3]
), ('nets.2.bias',
-0.1420
[torch.FloatTensor of size 1]
), ('nets.4.weight',
(0 ,0 ,.,.) =
  0.1981 -0.0250  0.2429
  0.3012  0.2428 -0.0114
  0.2878 -0.2134  0.1173
[torch.FloatTensor of size 1x1x3x3]
), ('nets.4.bias',
1.00000e-02 *
 -5.8426
[torch.FloatTensor of size 1]
)])
0 nets.0.weight
1 nets.0.bias
2 nets.2.weight
3 nets.2.bias
4 nets.4.weight
5 nets.4.bias

说明 .state_dict() 只是把模型的所有参数(以及 buffer,例如 BatchNorm 的 running_mean)以 OrderedDict 的形式存下来,key 是参数名,value 是对应的 tensor。通过

# Iterating an OrderedDict yields its keys; enumerate() adds each key's position.
for idx, key in enumerate(pretrained_net):
    print(idx, key)

可以得知这些参数的顺序。当然,如果要查看具体的值,可以用:

# .items() yields (parameter name, tensor) pairs in saved order.
for key, v in pretrained_net.items():
    print(key, v)
nets.0.weight
(0 ,0 ,.,.) =
 -0.2444 -0.3148  0.1626
  0.2531 -0.0859 -0.0236
  0.1635  0.1113 -0.1110

(1 ,0 ,.,.) =
  0.2374 -0.2931 -0.1806
 -0.1456  0.2264 -0.0114
  0.1813  0.1134 -0.2095
[torch.FloatTensor of size 2x1x3x3]

nets.0.bias
-0.3087
-0.2407
[torch.FloatTensor of size 2]

nets.2.weight
(0 ,0 ,.,.) =
 -0.2206 -0.1151 -0.0783
  0.0723 -0.2008  0.0568
 -0.0964 -0.1505 -0.1203

(0 ,1 ,.,.) =
  0.0131  0.1329 -0.1763
  0.1276 -0.2025 -0.0075
 -0.1167 -0.1833  0.1103
[torch.FloatTensor of size 1x2x3x3]

nets.2.bias
-0.1858
[torch.FloatTensor of size 1]

nets.4.weight
(0 ,0 ,.,.) =
 -0.1019  0.0534  0.2018
 -0.0600 -0.1389 -0.0275
  0.0696  0.0360  0.1560
[torch.FloatTensor of size 1x1x3x3]

nets.4.bias
1.00000e-03 *
 -5.6003
[torch.FloatTensor of size 1]

如果哪一天我们需要重写这个网络,比如改用 Net_new(它把每一层都作为类的一个属性),然后直接 load 之前保存的参数:

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """Same three-conv network as ``Net_old``, but every layer is its own attribute.

    Because each layer is registered directly on the module, the keys of
    ``state_dict()`` are ``conv1.weight``, ``conv1.bias``, ... ``conv3.bias``,
    which do not match the ``nets.<index>.*`` keys saved by ``Net_old``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.r1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(2, 1, 3)
        self.r2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 to ``x`` in that order."""
        for layer in (self.conv1, self.r1, self.conv2, self.r2, self.conv3):
            x = layer(x)
        return x

network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')

pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key in pretrained_net:
    print(key)

# Define new network, and directly load the state_dict.
# This load is EXPECTED to fail: the saved keys ("nets.0.weight", ...) do
# not match Net_new's keys ("conv1.weight", ...). Old PyTorch versions raise
# KeyError here; newer ones raise RuntimeError listing missing/unexpected keys.
new_network = Net_new()
new_network.load_state_dict(pretrained_net)


会出现 unexpected key 的错误(下面是旧版 PyTorch 的报错;较新的版本会抛出 RuntimeError,并同时列出 Missing key(s) 和 Unexpected key(s)):

nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
Traceback (most recent call last):
  File "Blog.py", line 44, in 
    new_network.load_state_dict(pretrained_net)
  File "/home/vis/xxx/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
    .format(name))
KeyError: 'unexpected key "nets.0.weight" in state_dict'

这是因为新网络的每一层都是以“属性形式”定义的。查看新网络的 state_dict,它的 key 是:

conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias

strict=False加载模型的正确解读

你可能会想到用 strict=False 来加载:

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """Same three-conv network as ``Net_old``, but every layer is its own attribute.

    Because each layer is registered directly on the module, the keys of
    ``state_dict()`` are ``conv1.weight``, ``conv1.bias``, ... ``conv3.bias``,
    which do not match the ``nets.<index>.*`` keys saved by ``Net_old``.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.r1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(2, 1, 3)
        self.r2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 to ``x`` in that order."""
        for layer in (self.conv1, self.r1, self.conv2, self.r2, self.conv3):
            x = layer(x)
        return x

old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')

pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key in new_network.state_dict():
    print(key)
print('-----After loading------')
# strict=False silently skips every key that does not match. None of the
# saved "nets.*" keys exist in Net_new (which only has conv1..conv3), so
# nothing is copied here.
new_network.load_state_dict(pretrained_net, strict=False)
# So you think that these two values are the same?? Hah!
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
for key in new_network.state_dict():
    print(key)

输出

nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
-----After loading------
-0.882688805461
0.34207585454
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias

数值一点变化都没有,说明 strict=False 并没有那么智能!它直接忽略不匹配的 key:key 相同就复制,不同就直接放弃赋值。为了验证这一点,下面在 Net_new 中也加入一个名为 nets 的属性,使它第一层的 key(nets.0.weight、nets.0.bias)与 Net_old 的相同:

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """Attribute-style network that additionally owns a Sequential named ``nets``.

    The ``nets`` attribute makes the keys ``nets.0.weight`` / ``nets.0.bias``
    appear in ``state_dict()``. They can therefore receive the first
    convolution of a ``Net_old`` checkpoint when loaded with ``strict=False``.
    The ``conv*`` layers keep their own keys and are not touched by such a load.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        # forward() applies conv3, so it must be defined here (the printed
        # state_dict of this network also lists conv3.weight / conv3.bias).
        self.conv3 = torch.nn.Conv2d(1, 1, 3)
        # Also add a 'nets' attribute to Net_new, so that its first layer has
        # the same key ("nets.0.*") as the first conv layer of Net_old.
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3)
        )

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 -> nets to ``x``."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        x = self.nets(x)
        return x

old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')

pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print(torch.sum(new_network.nets[0].weight.data))
for key in new_network.state_dict():
    print(key)
print('-----After loading------')
# Only "nets.0.weight" / "nets.0.bias" exist in both dicts, so with
# strict=False only new_network.nets[0] receives pretrained values.
new_network.load_state_dict(pretrained_net, strict=False)
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
# Hopefully, this value equals to 'old_network.nets[0].weight'
print(torch.sum(new_network.nets[0].weight.data))
for key in new_network.state_dict():
    print(key)

结果:

nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
-0.197643771768
0.862508803606
1.21658478677
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias
-----After loading------
-0.197643771768
0.862508803606
-0.197643771768
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
nets.0.weight
nets.0.bias

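可以发现 After loading 之后,new_network.nets[0] 的权重之和与 old_network.nets[0] 的一致,符合预期;而 conv1 的值仍然没有变化。
总结:用 strict=False 加载模型,就是“能塞则塞,不能塞则丢”。load_state_dict 是依据 key 来加载的,默认(strict=True)一旦有 key 不匹配就会报错。如果设置 strict=False,则直接忽略不匹配的 key,对于匹配的 key 则正常赋值。注意:key 相同但 tensor 形状不同时,即使 strict=False,较新版本的 PyTorch 也会报错。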

Strict=False的用途

所以说,当你训练好一个模型之后,如果想往里面加几层,用 strict=False 就可以很容易地加载预训练的参数(注意检查 key 是否匹配)。只要 key 能够匹配,就能正确赋值。

出现unexpected key module.xxx.weight问题

如果保存模型时网络被 nn.DataParallel 包装过,就会发现 state_dict 中所有的 key 都带有 module. 前缀。
这时把它加载到没有用 DataParallel 包装的模型中就会出错。其实只要移除这些前缀即可:

  from collections import OrderedDict

  # save_path: path of a checkpoint saved from an nn.DataParallel-wrapped model.
  pretrained_net = Net_old()
  pretrained_net_dict = torch.load(save_path)
  # nn.DataParallel stores parameters under a "module." prefix; strip it only
  # where present, so keys without the prefix are kept intact.
  prefix = 'module.'
  new_state_dict = OrderedDict()
  for k, v in pretrained_net_dict.items():
      name = k[len(prefix):] if k.startswith(prefix) else k
      new_state_dict[name] = v
  # load params
  pretrained_net.load_state_dict(new_state_dict)

总结

  • 保存的 state_dict 以“属性名.参数名”(如 conv1.weight)作为 key 来存储。如果这个属性是一个 Sequential,key 中还会带上子层在 Sequential 中的下标,例如 seqConvs.0.weight。
    当然,在定义的类中要拿到 Sequential 的某一层可以用 [],比如 self.seqConvs[0].weight。
  • strict=False 没有那么智能,遵循“有相同的 key 则赋值,否则直接丢弃”。

附加

由于前面的问题还没解决,即如何把用 Sequential 定义的网络的模型参数,加载到用“属性一层层”定义的网络中?
下面是一种比较ugly的方法:

import torch
import torch.nn as nn

class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out

class Net_new(nn.Module):
    """Attribute-style three-conv network that can copy weights from a ``Net_old`` checkpoint.

    Every layer is its own attribute, so ``state_dict()`` keys are
    ``conv1.*`` ... ``conv3.*``. Use ``_initialize_weights_from_net`` to map
    the Sequential-based ``Net_old`` weights onto these layers by position.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 to ``x`` and return the result."""
        # There is no 'nets' attribute on this network, so conv3 is the last layer.
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

    def _initialize_weights_from_net(self, save_path='t.pth'):
        """Copy conv weights and biases from a saved ``Net_old`` state_dict into this network.

        The checkpoint is first loaded into a fresh ``Net_old``, so its keys
        are validated. Its ``Conv2d`` layers are then matched to
        ``self.get_convs()`` by position.

        Args:
            save_path: path of the saved ``Net_old`` state_dict (default ``'t.pth'``).

        Raises:
            ValueError: if the two networks have a different number of conv layers.
        """
        # First load the net.
        pretrained_net = Net_old()
        pretrained_net.load_state_dict(torch.load(save_path))
        # Report success only after the load has actually succeeded.
        print('Successfully load model ' + save_path)

        # nn.Sequential is an iterable container of sub-modules, so it can be
        # walked in order and filtered down to the convolutions.
        old_convs = [m for m in pretrained_net.nets if isinstance(m, torch.nn.Conv2d)]
        new_convs = self.get_convs()
        if len(old_convs) != len(new_convs):
            raise ValueError(
                'pretrained net has %d conv layers, but this net has %d'
                % (len(old_convs), len(new_convs)))

        # Copy values in place instead of rebinding .data, so this network
        # keeps its own parameter storage rather than aliasing the
        # pretrained net's tensors.
        with torch.no_grad():
            for old_conv, new_conv in zip(old_convs, new_convs):
                print('Assign weight of pretrained model layer : ', old_conv, ' to layer: ', new_conv)
                new_conv.weight.copy_(old_conv.weight)
                new_conv.bias.copy_(old_conv.bias)

    def get_convs(self):
        """Return this network's conv layers in forward order."""
        return [self.conv1, self.conv2, self.conv3]

old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')


pretrained_net = torch.load('t.pth')

# Show keys of pretrained model
for key in pretrained_net:
    print(key)
print('****Before loading********')
new_network = Net_new()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print('-----New loading method------')
# Copies the conv weights of the saved Net_old into new_network by position.
new_network._initialize_weights_from_net()
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))

输出:

nets.0.weight
nets.0.bias
nets.2.weight
nets.2.bias
nets.4.weight
nets.4.bias
****Before loading********
0.510313585401
0.198701560497
-----New loading method------
Successfully load model t.pth
('Assign weight of pretrained model layer : ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1)))
('Assign weight of pretrained model layer : ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)), ' to layer: ', Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)))
0.510313585401
0.510313585401

搞定!

以上内容均来自原作者的博客,在此感谢作者的分享。下面给出另一种较为方便的加载模型的方法:按顺序把旧 state_dict 的 key 重命名为新网络的 key,再调用 load_state_dict:

import torch
from collections import OrderedDict
import torch.nn as nn


class Net_old(nn.Module):
    """Three 3x3 convolutions with in-place ReLUs, wrapped in one ``nn.Sequential``.

    Because every layer lives inside the ``nets`` Sequential, the keys of
    ``state_dict()`` are ``nets.<index>.weight`` / ``nets.<index>.bias``,
    where ``<index>`` is the layer's position in the Sequential (0, 2 and 4
    for the convolutions; the ReLUs at 1 and 3 have no parameters).
    """

    def __init__(self):
        super(Net_old, self).__init__()
        layers = [
            nn.Conv2d(1, 2, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(2, 1, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, 3),
        ]
        self.nets = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the Sequential stack and return the result."""
        out = self.nets(x)
        return out


class Net_new(nn.Module):
    """Same three-conv network as ``Net_old``, but every layer is its own attribute.

    ``state_dict()`` keys are ``conv1.weight`` ... ``conv3.bias``, so a
    ``Net_old`` checkpoint (``nets.<index>.*`` keys) must be re-keyed before
    it can be loaded.
    """

    def __init__(self):
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        """Apply conv1 -> r1 -> conv2 -> r2 -> conv3 to ``x`` and return the result."""
        # This network has no 'nets' attribute, so conv3 is the last layer.
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x


old_network = Net_old()
torch.save(old_network.cpu().state_dict(), 't.pth')
new_network = Net_new()

pretrained_net = torch.load('t.pth')
# Key names of Net_new, listed in the same order as the parameters appear in
# the saved Net_old state_dict (nets.0.* -> conv1.*, nets.2.* -> conv2.*, ...).
new_keys = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias']
# The renaming is purely positional, so stop early if the counts differ
# instead of hitting an IndexError or building a partially-renamed dict.
if len(new_keys) != len(pretrained_net):
    raise ValueError('expected %d tensors in checkpoint, found %d'
                     % (len(new_keys), len(pretrained_net)))
new_state_dict = OrderedDict(zip(new_keys, pretrained_net.values()))

print('****Before loading********')
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))
print('-----New loading method------')
new_network.load_state_dict(new_state_dict)
print(torch.sum(old_network.nets[0].weight.data))
print(torch.sum(new_network.conv1.weight.data))

输出:

****Before loading********
tensor(1.3759)
tensor(0.3301)
-----New loading method------
tensor(1.3759)
tensor(1.3759)

结果完全OK!

你可能感兴趣的:(pytorch,深度学习,pytorch,深度学习,加载模型,load)