On a failure to load the vgg16_bn pretrained model

Environment: PyTorch 0.4.1

from torchvision import models

# Runs inside the model; self.frontend is the VGG16_BN-based front end whose
# weights are to be overwritten with the pretrained values.
mod = models.vgg16_bn(pretrained=True)
self._initialize_weights()
# print(len(self.frontend.state_dict().items()))
# print(len(mod.state_dict().items()))
# Copy the pretrained weights entry by entry. This relies on both state dicts
# listing their entries in the same order for the copied prefix, so skipping
# index i in self.frontend also skips the matching entry of mod.
frontend_items = list(self.frontend.state_dict().items())
mod_items = list(mod.state_dict().items())
for i in range(len(frontend_items)):
    key = frontend_items[i][0]
    if "num_batches_tracked" in key:
        continue
    frontend_items[i][1].data[:] = mod_items[i][1].data[:]

Cause: PyTorch 0.4 introduced the num_batches_tracked buffer in BatchNorm (BN) layers.

In mod.state_dict().items(), a conv layer contributes two entries (weight, bias) and a BN layer four (weight, bias, running_mean, running_var), plus, under PyTorch 0.4+, the extra num_batches_tracked buffer. Pooling layers contribute no entries.
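A quick way to see this layout, including the extra buffer, is to print the state dict keys of the pretrained model. This is a minimal inspection sketch; the layer indices in the comments refer to torchvision's vgg16_bn.

from torchvision import models

mod = models.vgg16_bn(pretrained=True)
for name, tensor in mod.state_dict().items():
    print(name, tuple(tensor.shape))
# features.0.weight, features.0.bias                      -> the first conv layer (2 entries)
# features.1.weight, features.1.bias,
# features.1.running_mean, features.1.running_var,
# features.1.num_batches_tracked                          -> the first BN layer (4 + 1 entries)
# the max-pooling layers contribute nothing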

For now I adopt the following workaround:

if "num_batches_tracked" in xx:
  continue

The reason: by inspecting the VGG16_BN network I found that the num_batches_tracked entry of each BN layer is tensor(0), the same as the initial value in the network the weights are copied into (both are tensor(0)), so I chose to skip copying this entry.
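A minimal sketch of that check, assuming torchvision's key name features.1.num_batches_tracked for the first BN layer:

import torch
from torchvision import models

mod = models.vgg16_bn(pretrained=True)
print(mod.state_dict()['features.1.num_batches_tracked'])   # tensor(0), as observed above
print(torch.nn.BatchNorm2d(64).num_batches_tracked)          # tensor(0) in a freshly built layer as well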

To be filled in later: how others load a VGG16_BN pretrained model, from this post:

https://blog.csdn.net/u012494820/article/details/79068625

model = ...
pretrained_dict = ...          # state dict of the pretrained model
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
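As a concrete (hypothetical) usage of this recipe, assuming the target network shares its key names with torchvision's vgg16_bn checkpoint; keys that exist only in the target, such as num_batches_tracked when the checkpoint predates PyTorch 0.4, simply keep their initialized values:

from torchvision import models

model = models.vgg16_bn(pretrained=False)                  # target network, standing in for "model = ..."
pretrained_dict = models.vgg16_bn(pretrained=True).state_dict()
model_dict = model.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)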

Rewritten (the dict comprehension in step 1 expanded into an explicit loop):
model = ...
pretrained_dict = ...          # state dict of the pretrained model
model_dict = model.state_dict()

# 1. filter out unnecessary keys
temp = {}
for k, v in pretrained_dict.items():
    if k in model_dict:
        temp[k] = v
pretrained_dict = temp
# 2. overwrite entries in the existing state dict and load it
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
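The explicit loop behaves exactly like the dict comprehension; its only advantage is that it is easy to instrument, for example with a (hypothetical) print of the checkpoint keys that the target model does not have:

temp = {}
for k, v in pretrained_dict.items():
    if k in model_dict:
        temp[k] = v
    else:
        print('skipped key not present in the target model:', k)
pretrained_dict = temp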
        

 
