#!/usr/bin/env python2.7
#coding=utf-8
import caffe
import csv
import numpy as np
# np.set_printoptions(threshold='nan')
MODEL_FILE = 'inception_v3_rgb_deploy.prototxt'
PRETRAIN_FILE = 'inception_v3_kinetics_rgb_pretrained.caffemodel'
net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)
p = []
# every layer in this deploy net is assumed to carry exactly two blobs,
# a weight and a bias (BN layers need different handling, see the notes below)
for param_name in net.params.keys():
    # print(param_name)
    weight = net.params[param_name][0].data
    bias = net.params[param_name][1].data
    p.append(weight)
    p.append(bias)
np.save('params.npy', p)
Before saving to numpy, check that the two network architectures actually correspond (layer names, ordering, shapes, and so on).
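One way to do that check, sketched below: dump names and shapes from each side and compare them by eye. The Caffe half runs in the Python 2 environment against the net above; the PyTorch half runs in the Python 3 environment against net1 from the next snippet.
# Caffe side (Python 2): layer names and blob shapes, in order.
for name, blobs in net.params.items():
    print(name, [b.data.shape for b in blobs])

# PyTorch side (Python 3): state_dict keys and tensor shapes, in order.
for k, v in net1.state_dict().items():
    print(k, tuple(v.shape))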
#!/usr/bin/env python3.5
# load net :
"""
https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py
the caffemodel has pretrained conv.bias and torchvision's Inception3 doesn't, so change lines 35 and 321:
35 def __init__(self, num_classes=400, aux_logits=False, transform_input=False):
320 super(BasicConv2d, self).__init__()
321 self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
"""
import numpy as np
import torch
from collections import OrderedDict
from torch import Tensor
# assumes inception.py was patched as described in the docstring above
from torchvision.models.inception import Inception3

net1 = Inception3()
new_params = np.load('params.npy', encoding='bytes')
dict_new = net1.state_dict().copy()  # isinstance(dict_new, OrderedDict) is True
new_list = list(net1.state_dict().keys())
ind_new = 0
for i, k in enumerate(new_list):
    # print(i, k, ind_new, new_list[i])
    # use this when num_classes != 400; it skips the fc weight and bias
    # if k.split('.')[-1] not in ['weight', 'bias'] or k.split('.')[-2] == 'fc':
    if k.split('.')[-1] not in ['weight', 'bias']:
        continue
    # shape check: reshape raises if the caffe blob and the torch tensor disagree
    tmp = new_params[ind_new].reshape(net1.state_dict()[k].shape)
    dict_new[k] = Tensor(tmp)
    # print(Tensor(new_params[ind_new]).shape)
    ind_new += 1
net1.load_state_dict(dict_new)
torch.save(net1.state_dict(), 'pretrain_v3_params.pkl')
print('saved.')
# use pretrained params
net2 = Inception3()
net2.load_state_dict(torch.load('pretrain_v3_params.pkl'))
print('loaded.')
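After loading, it is worth verifying the conversion numerically: push one identical input through both frameworks and compare the outputs. A minimal sketch, assuming input.npy and out_caffe.npy are hypothetical files dumped on the Caffe side (net.blobs['data'].data and the net.forward() output):
x = torch.from_numpy(np.load('input.npy')).float()  # hypothetical dump, shape (1, 3, 299, 299)
net2.eval()  # inference mode for BN/dropout, matching caffe.TEST
with torch.no_grad():
    out = net2(x)
print('max abs diff:', np.abs(out.numpy() - np.load('out_caffe.npy')).max())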
The first snippet extracts and saves the weights with Python 2.7's caffe; the second loads them with Python 3.5's PyTorch.
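The .npy handoff is itself a pitfall: the saved object array is pickled by Python 2, so Python 3 needs encoding='bytes' to decode it, and on NumPy >= 1.16.3 you must also pass allow_pickle=True (the default flipped for security):
new_params = np.load('params.npy', allow_pickle=True, encoding='bytes')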
Pitfalls below.
References:
https://blog.csdn.net/u011762313/article/details/49851795
https://zhuanlan.zhihu.com/p/34147880
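The pitfall those two posts cover: standard Caffe implements batch norm as a BatchNorm layer (mean and variance, each multiplied by a moving-average factor stored in blobs[2]) followed by a Scale layer (gamma/beta), while PyTorch folds all four into one BatchNorm2d. A sketch of the Caffe-side extraction under that assumption, reusing the net from the first snippet; the layer names are illustrative and some forks store BN in a single custom layer, so check your prototxt:
# Caffe side (Python 2), standard BVLC BatchNorm + Scale pair.
# BatchNorm blobs: [0] = mean * s, [1] = var * s, [2] = s (moving-average factor).
bn = net.params['conv1_7x7_s2_bn']      # illustrative layer name
sc = net.params['conv1_7x7_s2_scale']   # illustrative layer name
s = bn[2].data[0]
factor = 0 if s == 0 else 1.0 / s       # mirrors caffe's scale-factor handling
running_mean = bn[0].data * factor      # -> bn.running_mean
running_var = bn[1].data * factor       # -> bn.running_var
gamma = sc[0].data                      # -> bn.weight
beta = sc[1].data                       # -> bn.bias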
# BN_inc
import numpy as np
import torch
from collections import OrderedDict
from torch import Tensor, nn
# assumes bninception comes from Cadene's pretrainedmodels package
from pretrainedmodels import bninception

net1 = bninception()
net1.last_linear = nn.Linear(1024, 400)  # Kinetics: 400 classes
new_params = np.load('/home/yaotiechui/pytorch-caffe/bn_params22.npy', encoding='bytes')
dict_new = net1.state_dict().copy()  # isinstance(dict_new, OrderedDict) is True
new_list = list(net1.state_dict().keys())
ind_new = 0
for i, k in enumerate(new_list):
    # print(i, k, ind_new, new_list[i])
    # e.g. inception_5b_pool_proj_bn.num_batches_tracked == [], Tensor(0):
    # 0-dim buffers have no caffe counterpart, so skip them (continue, not break)
    print(k, net1.state_dict()[k].dim())
    if net1.state_dict()[k].dim() == 0:
        continue
    print(net1.state_dict()[k].shape, new_params[ind_new].shape)
    tmp = new_params[ind_new].reshape(net1.state_dict()[k].shape)
    dict_new[k] = Tensor(tmp)
    print('new:', dict_new[k].shape)
    ind_new += 1
net1.load_state_dict(dict_new)
torch.save(net1.state_dict(), 'pretrain_bn_params.pkl')
print('saved.')
# use pretrained params
net2 = bninception()
net2.last_linear = nn.Linear(1024, 400)
net2.load_state_dict(torch.load('pretrain_bn_params.pkl'))
print('loaded.')
# the loading step can also be written as a method that tolerates missing keys:
def load_pre_model_dict(self, state_dict):
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name[:6] == "module":
            # strip the DataParallel prefix; adjust the slice to match how
            # many leading components your checkpoint actually carries
            name = '.'.join(name.split('.')[2:])
        # print('name,', name)
        if name not in own_state:
            continue
        print('load....', name)
        if isinstance(param, nn.Parameter):
            print('true')
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)
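A hedged usage example for the helper above: called as a plain function with the model as its first argument, so loading tolerates renamed or missing keys instead of raising:
net = bninception()
net.last_linear = nn.Linear(1024, 400)
load_pre_model_dict(net, torch.load('pretrain_bn_params.pkl'))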