加载图结构和参数:
import tensorflow as tf

# Locate the most recent checkpoint via the `checkpoint` index file in ./model/.
ckpt = tf.train.get_checkpoint_state('./model/')
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError("no checkpoint found in ./model/")
# The original snippet referenced an undefined `file_path`; derive both the
# graph file and the variable prefix from the checkpoint state instead.
model_path = ckpt.model_checkpoint_path
# Rebuild the graph structure, which is stored in the `.meta` file.
saver = tf.train.import_meta_graph(model_path + '.meta')
with tf.Session() as sess:
    # Restore the variable values saved under this checkpoint prefix
    # (they live in the two companion files next to the .meta file).
    saver.restore(sess, model_path)
参考链接
import numpy as np

# Load and dump the weight file. allow_pickle is needed because the file
# stores a pickled Python object rather than a plain array; latin1 decoding
# lets Python 3 read pickles that were written by Python 2.
pre_train = np.load("pnet.npy", encoding='latin1', allow_pickle=True)
print(pre_train)
参考链接
直接附代码:
import sys
import tensorflow as tf
import caffe
import numpy as np
import cv2

# NOTE(review): this session is not used anywhere in the visible script; it is
# kept only in case other code depends on it and can likely be removed.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
from tensorflow.python import pywrap_tensorflow

checkpoint_path = "./pnet.npy"  # path of the .npy weight file to convert
# The .npy stores a pickled object (hence allow_pickle=True); latin1 decoding
# lets Python 3 read pickles written by Python 2.
pre_train = np.load(checkpoint_path, allow_pickle=True, encoding="latin1")
# `.item()` unwraps the 0-d object array into the stored Python object,
# presumably a {layer_name: {param_name: ndarray}} dict -- confirm.
# Renamed from `dict`, which shadowed the builtin type.
weights_by_layer = pre_train.item()
dict1 = weights_by_layer.items()
cf_prototxt = "./det1.prototxt"  # Caffe network definition (prototxt)
cf_model = "./det1.caffemodel"  # Caffe model file to generate
def tensor4d_transform(tensor):
    """Return *tensor* with its four axes permuted into the order (3, 2, 0, 1).

    NOTE(review): presumably turns a TF conv kernel laid out as
    (H, W, C_in, C_out) into Caffe's (C_out, C_in, H, W) -- confirm.
    """
    caffe_axis_order = (3, 2, 0, 1)
    return tensor.transpose(*caffe_axis_order)
def tensor2d_transform(tensor):
    """Return *tensor* with its two axes swapped (a plain matrix transpose).

    NOTE(review): presumably maps TF dense weights (in, out) to Caffe
    InnerProduct weights (out, in) -- confirm.
    """
    swapped_axes = (1, 0)
    return tensor.transpose(*swapped_axes)
# Index of each known parameter kind inside a Caffe layer's `net.params[...]`
# list (Caffe keeps weights in blob 0 and biases in blob 1).
# NOTE(review): "alpha" as the PReLU slope key is an assumption matching the
# facenet-style MTCNN .npy dumps; confirm for other files.
_PARAM_INDEX = {"weights": 0, "biases": 1, "alpha": 0}


def _to_caffe_layout(value):
    """Reorder *value*'s axes from the TF layout to the Caffe layout.

    4-D arrays go through ``tensor4d_transform`` and 2-D arrays through
    ``tensor2d_transform``; other ranks (biases, PReLU slopes) are returned
    unchanged.
    """
    if value.ndim == 4:
        return tensor4d_transform(value)
    if value.ndim == 2:
        # NOTE(review): correct for a dense layer fed by another dense layer;
        # one fed directly by a conv output may additionally need its input
        # axis reordered (NHWC vs NCHW flattening) -- not handled here.
        return tensor2d_transform(value)
    return value


def _copy_into_blob(blob, value, what):
    """Write *value* (after layout conversion) into the Caffe data array *blob*.

    Raises:
        ValueError: if the element counts differ. The previous ``.flat``
            assignment silently repeated or truncated data in that case.
    """
    converted = _to_caffe_layout(np.asarray(value))
    if converted.size != blob.size:
        raise ValueError(
            "%s: source has shape %s but the Caffe blob has shape %s"
            % (what, converted.shape, blob.shape)
        )
    blob[...] = converted.reshape(blob.shape)


def tf2caffe(checkpoint_path, cf_prototxt, cf_model):
    """Copy the parameters stored in a ``.npy`` dump into a Caffe model.

    Args:
        checkpoint_path: Path of the ``.npy`` file. It is unpacked with
            ``.item()`` and iterated as ``{layer_name: {param_name: array}}``.
        cf_prototxt: Caffe network definition whose layer names match the
            keys of the ``.npy`` dict.
        cf_model: Output path for the generated ``.caffemodel``.

    Raises:
        KeyError: a non-PReLU layer from the ``.npy`` is missing from the net.
        ValueError: a parameter's size does not match the Caffe blob.
    """
    # Load from the argument; the original ignored `checkpoint_path` and read
    # a module-level global instead.
    weights_by_layer = np.load(
        checkpoint_path, allow_pickle=True, encoding="latin1"
    ).item()
    net = caffe.Net(cf_prototxt, caffe.TRAIN)
    for layer_name, params in weights_by_layer.items():
        if layer_name not in net.params:
            if "PReLU" in layer_name:
                # The original skipped every PReLU entry; keep tolerating a
                # prototxt that has no matching PReLU layer, but say so.
                print("- Skipping %s: not found in %s" % (layer_name, cf_prototxt))
                continue
            raise KeyError(
                "layer %r from %s not found in %s"
                % (layer_name, checkpoint_path, cf_prototxt)
            )
        for param_name, value in params.items():
            index = _PARAM_INDEX.get(param_name)
            if index is None:
                continue  # unknown parameter kind: ignored, as before
            _copy_into_blob(
                net.params[layer_name][index].data,
                value,
                "%s/%s" % (layer_name, param_name),
            )
    net.save(cf_model)
    print("\n- Finished.\n")
if __name__ == "__main__":
    # Run the conversion with the paths configured at the top of the script.
    tf2caffe(checkpoint_path=checkpoint_path, cf_prototxt=cf_prototxt, cf_model=cf_model)
参考
我用 Python 3 尝试转换 MTCNN 模型，遇到了各种报错，解决一个又冒出一个，实在调不下去，只好放弃。最后卡在了下面这个错误上：
Type Name Param Output
----------------------------------------------------------------------------------------------
Input data -- (1, 3, 12, 12)
Convolution conv1 (10, 3, 3, 3) (1, 10, 10, 10)
Pooling pool1 -- (1, 10, 5, 5)
Convolution conv2 (16, 10, 3, 3) (1, 16, 3, 3)
Convolution conv3 (32, 16, 3, 3) (1, 32, 1, 1)
Convolution conv4-1 (2, 32, 1, 1) (1, 2, 1, 1)
Convolution conv4-2 (4, 32, 1, 1) (1, 4, 1, 1)
Softmax prob1 -- (1, 2, 1, 1)
Converting data...
Saving data...
Traceback (most recent call last):
File "convert.py", line 61, in <module>
main()
File "convert.py", line 57, in main
args.phase)
File "convert.py", line 34, in convert
np.save(data_out, data)
File "/home/yaspeed/.local/lib/python3.6/site-packages/numpy/lib/npyio.py", line 536, in save
pickle_kwargs=pickle_kwargs)
File "/home/yaspeed/.local/lib/python3.6/site-packages/numpy/lib/format.py", line 633, in write_array
pickle.dump(array, fp, protocol=2, **pickle_kwargs)
AttributeError: Can't pickle local object 'DataInjector.load_using_caffe.<locals>.<lambda>'