python&C++ python读取tensorflow模型参数,写入txt文件,C++读取txt文件

发现用C++读.mat文件需要matlab的依赖,本机没有装matlab,只能用最朴素的方式来存储数据核读入数据了。

1.讲tensorflow的参数存入txt文件

存的模型与上一篇博客一致,只是这次改用txt。numpy自带的写入文件函数支持一维数组和二维数组的写入,但是卷积核这个是四维的,且shape为[卷积核高,卷积核宽,输入通道数,输出通道数],在这里我把shape转为[输出通道数,输入通道数,卷积核高,卷积核宽]存储,并将数据放大了1000倍,存为整数。

代码如下:

import numpy as np
import tensorflow as tf

def store_4d_array(kernel, filename):
    # store the kernel
    f = open(filename, 'w+')
    shape = kernel.shape
    num_out_channel = shape[3]
    num_in_channel = shape[2]
    num_width = shape[0]
    f.write(str(num_out_channel) + ',' + str(num_in_channel) + ',' + str(num_width) + ',' + str(num_width) + '\n')
    for index_out_channel in range(num_out_channel):
        for index_in_channel in range(num_in_channel):
            for index_row in range(num_width):
                for index_col in range(num_width):
                    f.write(str(int(kernel[index_row][index_col][index_in_channel][index_out_channel] * 1000)))
                    if index_col == num_width - 1:
                        f.write('\n')
                    else:
                        f.write(',')
    f.close()

def store_1d_2d_array(bias, filename):
    # store the bias
    bias = bias * 1000
    bias = bias.astype(int)
    np.savetxt(filename, bias, delimiter=',', fmt="%d")

if __name__ == "__main__":
    with tf.Session() as sess:
        # load the meta graph and weights
        saver = tf.train.import_meta_graph('model_2\minist.ckpt-70.meta')
        saver.restore(sess, tf.train.latest_checkpoint('model_2/'))

        # get weighs
        graph = tf.get_default_graph()
        conv1_w = sess.run(graph.get_tensor_by_name('conv1/w:0'))
        np.save("conv1_w", conv1_w)
        store_4d_array(conv1_w, "weights/conv1_w.txt")
        conv1_b = sess.run(graph.get_tensor_by_name('conv1/b:0'))
        store_1d_2d_array(conv1_b, "weights/conv1_b.txt")
        conv2_w = sess.run(graph.get_tensor_by_name('conv2/w:0'))
        store_4d_array(conv2_w, "weights/conv2_w.txt")
        conv2_b = sess.run(graph.get_tensor_by_name('conv2/b:0'))
        store_1d_2d_array(conv2_b, "weights/conv2_b.txt")

        fc1_w = sess.run(graph.get_tensor_by_name('fc1/w:0'))
        store_1d_2d_array(fc1_w, "weights/fc1_w.txt")
        fc1_b = sess.run(graph.get_tensor_by_name('fc1/b:0'))
        store_1d_2d_array(fc1_b, "weights/fc1_b.txt")

        fc2_w = sess.run(graph.get_tensor_by_name('fc2/w:0'))
        store_1d_2d_array(fc2_w, "weights/fc2_w.txt")
        fc2_b = sess.run(graph.get_tensor_by_name('fc2/b:0'))
        store_1d_2d_array(fc2_b, "weights/fc2_b.txt")

2.C++读取txt中的数据

上述过程的逆过程,因为都是按行存储,所以思路很简单,上代码:

#include 
#include 
#include 
#include 
#include 

using namespace std;

void read_4d_array(string filename, vector>>>& kernel)
{
	// 读取卷积核
	// 读文件
	ifstream infile(filename, ios::in);
	string line;
	getline(infile, line);
	cout << "卷积核结构:" << line << endl;
	// 存储卷积核的信息
	stringstream ss_line(line);
	string str;
	vector line_array;
	// 按照逗号分割,存成整型的数据
	while (getline(ss_line, str, ','))
	{
		stringstream str_temp(str);
		int int_temp;
		str_temp >> int_temp;
		line_array.push_back(int_temp);
	}
	int num_out_channel = line_array[0];
	int	num_in_channel = line_array[1];
	int	num_width = line_array[2];

	// 逐行读取文件的信息,并存储
	for (int index_out_channel = 0; index_out_channel < num_out_channel; index_out_channel++)
	{
		// 用来存储一个in_channel 卷积核
		vector>> one_in_kernel;
		for (int index_in_channel = 0; index_in_channel < num_in_channel; index_in_channel++)
		{
			// 用来存储一个二维的卷积核
			vector> one_kernel;
			for (int index_row = 0; index_row < num_width; index_row++)
			{
				getline(infile, line);
				stringstream tmp_line(line);
				// 用来存储卷积核的一行
				vector tmp_int_line;
				while (getline(tmp_line, str, ','))
				{
					stringstream str_temp(str);
					int int_temp;
					str_temp >> int_temp;
					tmp_int_line.push_back(int_temp);
				}
				one_kernel.push_back(tmp_int_line);
			}
			one_in_kernel.push_back(one_kernel);
		}
		kernel.push_back(one_in_kernel);
	}
}
void read_1d_array(string filename, vector& bias)
{
	// 读取偏置项,偏置项一个数字占一行
	// 打开文件
	ifstream infile(filename, ios::in);
	string line;
	while (getline(infile, line))
	{
		stringstream tmp_line(line);
		int int_tmp;
		tmp_line >> int_tmp;
		bias.push_back(int_tmp);
	}
}
void read_2d_array(string filename, vector>& weights)
{
	// 读取二维存储的文件,主要是矩阵乘的权重向量、
	ifstream infile(filename, ios::in);
	string line;
	while (getline(infile, line))
	{
		stringstream ss_line(line);
		string str;
		vector line_array;
		// 按照逗号分割,存成整型的数据
		while (getline(ss_line, str, ','))
		{
			stringstream str_temp(str);
			int int_temp;
			str_temp >> int_temp;
			line_array.push_back(int_temp);
		}
		weights.push_back(line_array);
	}
}

你可能感兴趣的:(Piecemeal)