TensorFlow 遇见 Android Studio 在手机上运行keras的H5模型

由于项目需要，可能会用到 TensorFlow 和 Keras 的模型，并且需要在安卓 App 上运行模型。实现这一功能之后，我把过程总结了一下，分享出来。

一、在PC端训练Keras模型

下面给出我的示例代码：

from __future__ import print_function

import keras

from keras.models import Sequential,save_model,load_model

from keras.layers import Dense, Dropout, Flatten

from keras.layers import Conv2D, MaxPooling2D

from keras import backend as K

 

import numpy as np

# Dummy training set: num_samples single-channel 28x28 inputs.
num_samples = 9
num_classes = 8
data = np.random.rand(num_samples, 28, 28, 1)

# categorical_crossentropy expects each target row to be a class
# distribution (normally one-hot). Uniform random values in [0, 1) do not sum
# to 1, so draw a random class index per sample and one-hot encode it instead.
label = np.eye(num_classes)[np.random.randint(0, num_classes, size=num_samples)]

model = Sequential()

# The first layer is named 'op_data'; the Android code feeds the graph input
# node "op_data_input", which corresponds to this layer's input.
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu', padding='same',
                 input_shape=(28, 28, 1), name='op_data'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.5))

# Named softmax output layer; the Android code reads "op_out/Softmax".
model.add(Dense(num_classes, activation='softmax', name='op_out'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])

model.fit(data, label, epochs=3)

 

二、保存Keras模型为H5文件

import os

h5_path = 'data/keras.h5'

# Create the target directory first; writing the HDF5 file into a missing
# directory fails.
h5_dir = os.path.dirname(h5_path)
if h5_dir and not os.path.isdir(h5_dir):
    os.makedirs(h5_dir)

# save_model(model, path) does the same as model.save(path).
save_model(model, h5_path)

# Reload the file and print the summary of the *reloaded* model, so the check
# actually inspects what was written to disk (not the in-memory model).
newmodel = load_model(h5_path)
newmodel.summary()

# Relevant excerpt of the summary output:
#
# op_data (Conv2D) (None, 28, 28, 32) 320
#
# op_out (Dense) (None, 8) 264

二（续）、将Keras模型转换并保存为PB文件

 

import tensorflow as tf

import tensorflow.keras as keras

from tensorflow.python.framework import graph_io

from keras import backend as K

from keras.models import load_model

 

# Directory for exported model files (same value as keras2pb's default out_folder).
output_folder = "data"

 

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=False):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's list
                        is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        global_vars = tf.global_variables()
        freeze_var_names = list(
            set(v.op.name for v in global_vars).difference(keep_var_names or []))

        # Build a new list: `output_names += ...` on the argument would
        # silently append to the caller's list.
        output_names = list(output_names or [])
        # Listing the variable ops as outputs keeps them (converted to
        # constants) from being pruned out of the frozen graph.
        output_names += [v.op.name for v in global_vars]

        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

 

def keras2pb(h5_path='keras.h5', out_folder='data', out_pb='data.pb'):
    """Convert a Keras HDF5 model file into a frozen TensorFlow .pb graph.

    Args:
        h5_path: path of the .h5 file produced by keras save_model / model.save.
        out_folder: directory the .pb file is written to; created (including
            parent directories) if it does not exist.
        out_pb: file name of the .pb file inside out_folder.

    Returns:
        The path of the written .pb file.
    """
    # Local imports: `os` and `osp` were used here but never imported.
    import os
    import os.path as osp

    if not osp.isdir(out_folder):
        os.makedirs(out_folder)

    # Switch Keras to inference mode *before* the model is loaded, so the
    # graph being exported is built with learning phase 0 (e.g. Dropout off).
    K.set_learning_phase(0)
    keras_model = load_model(h5_path)

    print('in layer:', keras_model.input.name)
    print('out layer:', keras_model.output.name)

    # Do not use `with K.get_session() as sess:` -- leaving that block closes
    # Keras' global session, which breaks any later Keras call (including a
    # second keras2pb call) with "Attempted to use a closed Session".
    sess = K.get_session()
    frozen_graph = freeze_session(sess, output_names=[keras_model.output.op.name])
    graph_io.write_graph(frozen_graph, out_folder, out_pb, as_text=False)

    out_path = osp.join(out_folder, out_pb)
    print('save keras model as pb file at: ', out_path)
    return out_path

 

三、调用Keras PB文件，实现Java类PbKeras

package com.example.tan.tfmodel;

import android.content.Context;
import android.content.res.AssetManager;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class PbKeras {
    // Names of the input and output nodes in the frozen graph. They match the
    // Keras layers named 'op_data' (input) and 'op_out' (softmax output) in
    // the training script.
    private static final String inputName = "op_data_input";
    private static final String outputName = "op_out/Softmax";
    // Output node names as an array; presumably passed to tfInfer.run(...) in
    // getPredict() (not fully visible here) -- confirm.
    String[] outputNames = new String[] {outputName};

    // Inference wrapper around the frozen graph, created in the constructor.
    TensorFlowInferenceInterface tfInfer;
    static { // Load the native library libtensorflow_inference.so.
        System.loadLibrary("tensorflow_inference");
    }
    /**
     * Loads the frozen .pb graph from the app's assets.
     *
     * @param assetManager asset manager used to open the model file
     * @param modePath     path of the model inside assets (e.g. "keras.pb")
     */
    PbKeras(AssetManager assetManager, String modePath) {
        // Create the TensorFlowInferenceInterface for the given model file.
        tfInfer = new TensorFlowInferenceInterface(assetManager,modePath);
    }


    public int getPredict() {
        float[] inputs = new float[784];
        for(int i=0;i

四、在Activity中调用PbKeras，实现演示功能

package com.example.tan.tfmodel;

import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;

public class ActPb extends AppCompatActivity {
    TextView txt;
    PbTf pbTf;
    PbKeras pbKeras;

    /**
     * Inflates the demo layout, loads both frozen models from assets and
     * wires each button to run one model and display its predicted index.
     */
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_pbtf);

        pbTf = new PbTf(getAssets(), "cnn.pb");
        pbKeras = new PbKeras(getAssets(), "keras.pb");
        txt = findViewById(R.id.txt);

        // A single listener serves both buttons and dispatches on the view id.
        View.OnClickListener predictListener = new View.OnClickListener() {
            @Override
            public void onClick(View clicked) {
                onPredictClicked(clicked.getId());
            }
        };

        Button tfButton = findViewById(R.id.btn_tf);
        tfButton.setOnClickListener(predictListener);

        Button kerasButton = findViewById(R.id.btn_keras);
        kerasButton.setOnClickListener(predictListener);
    }

    /** Runs the model that belongs to the clicked button and shows the result. */
    private void onPredictClicked(int buttonId) {
        if (buttonId == R.id.btn_tf) {
            txt.setText("pbTf 识别结果:" + pbTf.getPredict());
        } else if (buttonId == R.id.btn_keras) {
            txt.setText("pbKeras 识别结果:" + pbKeras.getPredict());
        }
    }
}

 

五、截图示意

TensorFlow 遇见 Android Studio 在手机上运行keras的H5模型_第1张图片

TensorFlow 遇见 Android Studio 在手机上运行keras的H5模型_第2张图片

你可能感兴趣的:(程序员)