由于项目需要，可能会用到 TensorFlow 和 Keras 的模型，并且需要在 App（Android）上运行模型。实现这一功能之后，我把整个过程总结了一下，分享出来。
一、在PC端训练Keras模型
下面给出我的示例代码：
from __future__ import print_function
import keras
from keras.models import Sequential,save_model,load_model
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import numpy as np
# Toy training data: 9 random 28x28 single-channel inputs.
data = np.random.rand(9, 28, 28, 1)
# categorical_crossentropy expects every label row to be a probability
# distribution (normally one-hot). Uniform random values in [0, 1) do not sum
# to 1, so build one-hot labels over the output classes instead.
num_classes = 8
label = np.eye(num_classes)[np.random.randint(0, num_classes, size=len(data))]

model = Sequential()
# NOTE(review): the layer names 'op_data' / 'op_out' are presumably what the
# exported graph's input/output node names are derived from (the Android side
# uses 'op_data_input' and 'op_out/Softmax') — keep them in sync.
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu', padding='same',
                 input_shape=(28, 28, 1), name='op_data'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax', name='op_out'))
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])
model.fit(data, label, epochs=3)
二、保存Keras模型为H5文件
import os

h5_path = 'data/keras.h5'
# save_model does not create missing parent directories, so make sure the
# target directory exists before writing the HDF5 file.
h5_dir = os.path.dirname(h5_path)
if h5_dir and not os.path.isdir(h5_dir):
    os.makedirs(h5_dir)
save_model(model, h5_path)  # equivalent to model.save(h5_path)
# Reload from disk and summarize the *reloaded* model, so the printout
# actually verifies the H5 round trip.
newmodel = load_model(h5_path)
newmodel.summary()
# Abridged summary output (first and last layers only):
#   op_data (Conv2D)   (None, 28, 28, 32)   320
#   op_out (Dense)     (None, 8)            264
二（续）、将Keras模型保存为PB文件
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.python.framework import graph_io
from keras import backend as K
from keras.models import load_model
# Output directory name for exported files.
# NOTE(review): not referenced by the code visible here; keras2pb takes its own out_folder argument.
output_folder = "data"
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=False):
    """
    Freeze the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph is
    pruned so that subgraphs not necessary to compute the requested outputs
    are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs (iterable of str).
                        The caller's list is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_names or []))
        # Build a fresh list: using `+=` on the argument itself would silently
        # extend the list object the caller passed in.
        output_names = list(output_names or [])
        # Also list every variable node as an output so the frozen (constant)
        # variable nodes survive pruning even if unreachable from the outputs.
        output_names += global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
def keras2pb(h5_path='keras.h5', out_folder='data', out_pb='data.pb'):
    """
    Convert a Keras HDF5 model into a frozen TensorFlow graph (binary .pb).

    @param h5_path Path of the Keras model saved with save_model / model.save.
    @param out_folder Directory to write the .pb file into; created if missing.
    @param out_pb File name of the binary GraphDef written inside out_folder.
    @return The path of the written .pb file.
    """
    # The function body needs `os`; import it here so the snippet is self-contained.
    import os

    if not os.path.isdir(out_folder):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(out_folder)
    # Switch Keras to inference mode before the model graph is built, so
    # training-only behaviour (e.g. Dropout) is exported in test-time form.
    K.set_learning_phase(0)
    keras_model = load_model(h5_path)
    print('in layer:', keras_model.input.name)
    print('out layer:', keras_model.output.name)
    # Do not use `with K.get_session() as sess:` — leaving the with-block
    # closes Keras's global session and breaks later Keras calls.
    sess = K.get_session()
    frozen_graph = freeze_session(sess, output_names=[keras_model.output.op.name])
    graph_io.write_graph(frozen_graph, out_folder, out_pb, as_text=False)
    pb_path = os.path.join(out_folder, out_pb)
    print('save keras model as pb file at: ', pb_path)
    return pb_path
三、在Java中调用Keras导出的PB文件：实现PbKeras类
package com.example.tan.tfmodel;
import android.content.Context;
import android.content.res.AssetManager;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * Wraps the frozen Keras graph for use on Android via TensorFlowInferenceInterface.
 */
public class PbKeras {
// Names of the model's input and output nodes in the graph.
// NOTE(review): presumably derived from the Keras layer names 'op_data' (input layer,
// auto-suffixed "_input") and 'op_out' (softmax Dense layer); confirm against the
// 'in layer' / 'out layer' names printed by the export script.
private static final String inputName = "op_data_input";
private static final String outputName = "op_out/Softmax";
// Output node name(s) in array form.
String[] outputNames = new String[] {outputName};
TensorFlowInferenceInterface tfInfer;
static {// Load the native libtensorflow_inference.so library
System.loadLibrary("tensorflow_inference");
}
/**
 * @param assetManager asset manager passed to TensorFlowInferenceInterface (ActPb passes getAssets())
 * @param modePath model path passed to TensorFlowInferenceInterface (ActPb passes "keras.pb")
 */
PbKeras(AssetManager assetManager, String modePath) {
// Create the TensorFlowInferenceInterface object for this model
tfInfer = new TensorFlowInferenceInterface(assetManager,modePath);
}
/**
 * Runs one prediction and returns an int (ActPb displays it as the recognition result,
 * presumably the predicted class index).
 * NOTE(review): the rest of this method is truncated in this source — most likely a '<'
 * was swallowed during HTML extraction. Restore it from the original project.
 */
public int getPredict() {
// 784 = 28 * 28, matching the model's (28, 28, 1) input shape.
float[] inputs = new float[784];
for(int i=0;i
四、在Activity中调用PbKeras，实现演示功能
package com.example.tan.tfmodel;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;
/**
 * Demo screen: loads two frozen models and shows the prediction of whichever
 * model's button the user taps.
 */
public class ActPb extends AppCompatActivity {
    // Label that displays the most recent prediction.
    TextView txt;
    // Model exported from TensorFlow (constructed with getAssets(), "cnn.pb").
    PbTf pbTf;
    // Model exported from Keras (constructed with getAssets(), "keras.pb").
    PbKeras pbKeras;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_pbtf);

        pbTf = new PbTf(getAssets(), "cnn.pb");
        pbKeras = new PbKeras(getAssets(), "keras.pb");
        txt = findViewById(R.id.txt);

        Button tfButton = findViewById(R.id.btn_tf);
        tfButton.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                showPrediction("pbTf 识别结果:", pbTf.getPredict());
            }
        });

        Button kerasButton = findViewById(R.id.btn_keras);
        kerasButton.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                showPrediction("pbKeras 识别结果:", pbKeras.getPredict());
            }
        });
    }

    /** Writes the given prefix followed by the predicted index into the result label. */
    private void showPrediction(String prefix, int index) {
        txt.setText(prefix + String.valueOf(index));
    }
}
五、截图示意