python 读取tensorflow模型的参数,并重新写入.mat文件

因为现有工作一部分在python,一部分在C++,需要将tensorflow已经训练好的模型应用到C++程序中,选择比较笨的折中办法,先将模型参数写入mat文件,然后再用C++读取mat文件。

读取模型数据并存储为.mat文件

我训练的模型就是tf-MNIST,tensorflow的官方例子。

关于tensorflow模型的保存和读取可看这篇文章。

import numpy as np
import tensorflow as tf
import scipy.io as sio


if __name__ == "__main__":
    with tf.Session() as sess:
        # load the meta graph and weights
        saver = tf.train.import_meta_graph('model_2\minist.ckpt-70.meta')
        saver.restore(sess, tf.train.latest_checkpoint('model_2/'))

        # get weighs
        graph = tf.get_default_graph()
        conv1_w = sess.run(graph.get_tensor_by_name('conv1/w:0'))
        sio.savemat("weights/conv1_w.mat", {"array": conv1_w})
        conv1_b = sess.run(graph.get_tensor_by_name('conv1/b:0'))
        sio.savemat("weights/conv1_b.mat", {"array": conv1_b})
        conv2_w = sess.run(graph.get_tensor_by_name('conv2/w:0'))
        sio.savemat("weights/conv2_w.mat", {"array": conv2_w})
        conv2_b = sess.run(graph.get_tensor_by_name('conv2/b:0'))
        sio.savemat("weights/conv2_b.mat", {"array": conv2_b})

        fc1_w = sess.run(graph.get_tensor_by_name('fc1/w:0'))
        sio.savemat("weights/fc1_w.mat", {"array": fc1_w})
        fc1_b = sess.run(graph.get_tensor_by_name('fc1/b:0'))
        sio.savemat("weights/fc1_b.mat", {"array": fc1_b})

        fc2_w = sess.run(graph.get_tensor_by_name('fc2/w:0'))
        sio.savemat("weights/fc2_w.mat", {"array": fc2_w})
        fc2_b = sess.run(graph.get_tensor_by_name('fc2/b:0'))
        sio.savemat("weights/fc2_b.mat", {"array": fc2_b})


你可能感兴趣的:(Piecemeal)