caffe库的编译可以参考我之前写的一篇博客:ImportError: dynamic module does not define module export function (PyInit__caffe)问题解决记录_chen_zn95的博客-CSDN博客
安装好后,使用以下脚本便可将caffe模型各层的参数名和参数值保存成dict,并序列化到weights.pkl文件中:
import pickle as pkl
import caffe
MODEL_FILE = 'xxx.prototxt'
PRETRAIN_FILE = 'xxx.caffemodel'


def extract_layer_params(layer_params):
    """Map one Caffe layer's parameter blobs to PyTorch-style named arrays.

    Args:
        layer_params: Sequence of blobs belonging to a single layer (one
            entry of ``net.params``). Only ``len()``, integer indexing and
            the ``.data`` attribute of each blob are used.

    Returns:
        dict: ``{'weight'}`` for one blob, ``{'weight', 'bias'}`` for two
        blobs, and ``{'running_mean', 'running_var'}`` for three blobs.

    Raises:
        RuntimeError: If the layer has a blob count other than 1, 2 or 3.
    """
    blob_count = len(layer_params)
    if blob_count == 1:
        return {'weight': layer_params[0].data}
    if blob_count == 2:
        return {
            'weight': layer_params[0].data,
            'bias': layer_params[1].data,
        }
    if blob_count == 3:
        # Three blobs are treated as a BatchNorm layer: accumulated mean,
        # accumulated variance and the scale factor both must be divided by.
        mean_sum = layer_params[0].data
        var_sum = layer_params[1].data
        scale = layer_params[2].data
        if scale.flat[0] == 0:
            # Caffe uses a factor of 0 when the stored scale is 0; dividing
            # here would fill the statistics with nan/inf instead.
            return {
                'running_mean': mean_sum * 0,
                'running_var': var_sum * 0,
            }
        return {
            'running_mean': mean_sum / scale,
            'running_var': var_sum / scale,
        }
    raise RuntimeError(
        "unsupported number of parameter blobs: %d\n" % blob_count)


if __name__ == '__main__':
    net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)

    # Order in which the entries of one layer are reported on stdout.
    report_order = ('weight', 'bias', 'running_var', 'running_mean')

    name_weights = {}
    for param_name, layer_params in net.params.items():
        entry = extract_layer_params(layer_params)
        name_weights[param_name] = entry
        shape_lines = [
            '\t%s (%s)' % (str(entry[blob_name].shape), blob_name)
            for blob_name in report_order
            if blob_name in entry
        ]
        print('%s:\n%s' % (param_name, '\n'.join(shape_lines)))

    # save weight
    with open('weights.pkl', 'wb') as f:
        pkl.dump(name_weights, f, protocol=2)
这里有两个思路:一是根据权重名来匹配,二是根据权重的shape来匹配。但第二个方法有个问题:如果网络中有两个或两个以上shape相同的权重,那么仅根据权重的shape来匹配就会出错。下面分别介绍以上两个思路。
根据权重名来匹配的方法比较繁琐,要求pytorch模型的参数名与caffe模型的参数名保持一致;如果不一致,则需要自己写一个dict进行映射。具体操作如下:
import pickle as pkl
import torch
import copy
# Load the exported Caffe weights into a copy of the PyTorch model by name.
model = xxx  # placeholder: replace with an instance of the PyTorch network
model1 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:  # weights.pkl: dict written by the Caffe export script
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # weights.pkl that you generated yourself.
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# Flatten {layer: {blob: array}} into "<layer>.<blob>" keys, the naming
# scheme load_state_dict() expects.
state_dict = {
    layer_name + "." + blob_name: torch.from_numpy(array)
    for layer_name, blobs in name_weights.items()
    for blob_name, array in blobs.items()
}

# strict=True: every key must match the model exactly, in both directions.
model1.load_state_dict(state_dict, strict=True)
另一种实现是直接对pytorch模型的参数赋值,代码如下:
import pickle as pkl
import torch
import copy
# Copy the exported Caffe weights straight into the tensors of a model copy.
model = xxx  # placeholder: replace with an instance of the PyTorch network
model2 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # weights.pkl that you generated yourself.
    name_weights = pkl.load(wp, encoding='iso-8859-1')

with torch.no_grad():
    # state_dict() covers parameters AND buffers, so the exported
    # running_mean / running_var entries are assigned as well; iterating
    # named_parameters() alone would silently skip them.
    for name, target in model2.state_dict().items():
        # Split on the last dot so nested module paths ("a.b.weight") work.
        layer_name, _, blob_name = name.rpartition(".")
        array = name_weights.get(layer_name, {}).get(blob_name)
        if array is None:
            continue  # no exported value for this tensor
        # In-place copy writes into the model's own storage; it raises if
        # the exported array cannot be broadcast to the target's shape.
        target.copy_(torch.from_numpy(array))
import pickle as pkl
import torch
import copy
# Copy exported Caffe weights into a model copy, checking both name and shape.
model = LightCNN_ir_eye()
model3 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # weights.pkl that you generated yourself.
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# Convert every exported array once, keyed by "<layer>.<blob>", instead of
# re-converting it for each tensor of the model.
caffe_tensors = {
    key + "." + k: torch.from_numpy(v)
    for key, value in name_weights.items()
    for k, v in value.items()
}

with torch.no_grad():
    # state_dict() includes buffers (running_mean / running_var) in addition
    # to the parameters that named_parameters() yields.
    for name, target in model3.state_dict().items():
        source = caffe_tensors.get(name)
        if source is None:
            continue
        # Matching on shape alone is ambiguous when several tensors share a
        # shape, so the name must match too; the shape test is kept as a guard.
        if source.shape == target.shape:
            target.copy_(source)
import cv2
import numpy as np
import torch
# Run the same image through the three converted models and compare outputs.
img = cv2.imread("xxx.jpg")
if img is None:
    # cv2.imread returns None instead of raising when the file is unreadable.
    raise FileNotFoundError("cannot read image: xxx.jpg")
img = cv2.resize(img, (width, height))  # width/height: placeholders for the network input size
img = np.transpose(img, (2, 0, 1))  # HWC -> CHW (was misspelled `np.tranpose`)
img = np.expand_dims(img, axis=0)  # add the batch dimension
inputs = torch.from_numpy(img).float()

# Inference mode: dropout / batch-norm would otherwise make the outputs of
# identical weights differ from run to run.
model1.eval()
model2.eval()
model3.eval()
with torch.no_grad():
    out1 = model1(inputs)
    out2 = model2(inputs)
    out3 = model3(inputs)

print(out1)
print(out2)
print(out3)
for o1, o2, o3 in zip(out1, out2, out3):
    print(o1 == o2)
    print(o1 == o3)