TensorFlow遇见Android Studio 在手机上运行ckpt模型

由于项目需要,可能会用到tensorflow和keras的模型,并且需要在App(安卓)上跑模型,实现了这一功能之后,把过程总结一下,分享了出来。

一、TensorFlow模型构建训练

import tensorflow as tf

#由于我的目的就只是把较复杂的模型实现在App,就没有训练,这里实现的是手写数字识别。

def defModel(x): #将所有的图片原本100*100

    #第一个卷积层(100——>50)

    conv=tf.layers.conv2d(

        inputs=x,

        filters=32,

        kernel_size=[5, 5],

        strides=[2,2],

        padding="same",

        activation=tf.nn.relu,

        name='conv')

    pool=tf.layers.max_pooling2d(inputs=conv, pool_size=[2, 2], strides=2,name='pool')

    flatten = tf.layers.flatten(inputs=pool,name='flatten')

    dense= tf.layers.dense(

        inputs=flatten,

        units=8,

        activation='relu',

        name='dense')

    return dense


 

op_data = tf.placeholder(dtype=tf.float32,shape=(None,28,28,1),name='op_data')

op_label = tf.placeholder(dtype=tf.float32,shape=(None,8),name='op_label')

model = defModel(op_data)#print(model)

op_out = tf.nn.softmax(model,name='op_out')

 

#训练中才需要的参数

# # Define loss and optimizer

# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=op_label, logits=model))#0.9

# train_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(loss)

 

# #define accuracy on labels

# correct_prediction = tf.equal(tf.argmax(op_label, 1), tf.argmax(model,1))

# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


 

#把当前的模型保存为PB文件,PB文件会保存当前tensorflow的模型,将其他值固化为常量

from tensorflow.python.framework import graph_util

saver = tf.train.Saver()

with tf.Session() as sess: #开始一个会话

    sess.run(tf.global_variables_initializer())

# 第一个参数 sess指定为当前的Session

# 第二个参数是要保存的 图的定义,默认是当前图

# 然后是要输出的节点

    const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['op_data','op_out'])

    with tf.gfile.FastGFile('cnn.pb', mode='wb') as f: # 这里是选择要保存的位置

        f.write(const_graph.SerializeToString())

    print('save pb ok!')

接下来使用PB查看函数查看节点信息

def check_pb(pb_path='model.pb'):

#输出保存的模型中参数名字及对应的值

with tf.gfile.GFile(pb_path, "rb") as f:

#读取模型数据

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

#得到模型中的计算图和数据

with tf.Graph().as_default() as graph:

# 这里的Graph()要有括号,不然会报TypeError

tf.import_graph_def(graph_def, name="")

# #导入模型中的图到现在这个新的计算图中,不指定名字的话默认是import

 

for op in graph.get_operations():

# 打印出图中的节点信息

print(op.name, op.values())

 

'''

op_data (,)

op_label (,)

op_out/dimension (,)

op_out (,)

'''

通过以上程序可以看出我们主要的模型已经搭建完成并且获得输入输出的节点名称,这在App端调用会用到。

二、在App端实现PBTf  Java类

package com.example.tan.tfmodel;

import android.content.res.AssetManager;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class PbTf {
    //模型中输入、输出变量的名称
    private static final String inputName = "op_data";
    private static final String outputName = "op_out";
    String[] outputNames = new String[] {outputName};

    TensorFlowInferenceInterface tfInfer;
    static {//加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
    }
    PbTf(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface对象
        tfInfer = new TensorFlowInferenceInterface(assetManager,modePath);
    }


    public int getPredict() {
        float[] inputs = new float[784];
        for(int i=0;i

TensorFlow遇见Android Studio 在手机上运行ckpt模型_第1张图片

你可能感兴趣的:(程序员)