This article is a repost.
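Overview
With the wide adoption of deep learning and the open-sourcing of TensorFlow, more and more models are being deployed on mobile devices. This article shares some of the lessons the author learned while setting up TensorFlow on Android, in the hope that they help you.

Installing the CPU version of TensorFlow on Mac
If you don't have a capable GPU, you can install the CPU-only build of TensorFlow. On Linux and Mac, TensorFlow can be installed for both Python 2 and Python 3. On Windows, only Python 3 is supported.
1. Install Bazel, the build tool TensorFlow depends on. It is used later to build the jar and the .so library that let TensorFlow run on Android. On Mac, install it with Homebrew (brew install bazel), or install the appropriate version by following the official Bazel documentation.
2. Install the CPU-only version of TensorFlow with pip: pip install tensorflow
3. Verify that TensorFlow was installed correctly: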
import tensorflow as tf
# Print the installed TensorFlow version
print(tf.__version__)
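Building the jar and .so library
1. Clone the tensorflow repository from GitHub to your local machine.
2. Edit the WORKSPACE file in the tensorflow directory (the android_sdk_repository and android_ndk_repository entries) and change the SDK and NDK paths to your local paths. The SDK API level must be 23 or higher. NDK r12b is recommended, because newer NDK versions cause problems when building with Bazel. Set build_tools_version to match the build tools you have installed.
3. Build the jar and .so library with the commands below. Reference:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android
Command to build the .so library (choose the target CPU architecture with --cpu):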
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
   --crosstool_top=//external:android/crosstool \
   --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
   --cpu=armeabi-v7a
Location of the .so library:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
Command to build the jar:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
Location of the jar:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
4. Alternatively, you can use the prebuilt .so files that TensorFlow publishes from its nightly builds:
http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/libtensorflow_inference.so/
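Setting up the Android side
1. Put the jar into the app/libs directory and add the dependency to build.gradle: compile files('libs/libandroid_tensorflow_inference_java.jar'). Newer versions of the Android Gradle plugin use implementation instead of compile.
2. Create a jniLibs folder under src/main. Put the generated .so library in a subfolder named after its ABI, for example src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so.
3. Save the model trained on the PC as a .pb file: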
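import tensorflow as tf
from tensorflow.python.framework import graph_util

# sess is the tf.Session holding the trained variables; 'output' is the name of the output node
output_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile("path/to/xxx.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())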
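4. Put the .pb file in the src/main/assets directory. The label file goes there too, with one label per line.
5. The Android code below walks through how TensorFlow is set up on the Android side.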
import android.content.res.AssetManager;
import android.os.Trace;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

public class TensorFlowAudioClassifier implements Classifier {
    private static final String TAG = "TensorFlowAudioClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 3;
    private static final float THRESHOLD = 0.0f;

    // Config values.
    // Name of the input node, without the trailing ':0' (e.g. 'input')
    private String inputName;
    // Name of the output node (same convention as the input node name)
    private String outputName;
    // Side length of the input matrix (the input is usually square)
    private int inputSize;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;

    // Private constructor: instances are created only through the static factory method create()
    private TensorFlowAudioClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying audio.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square feature matrix of inputSize x inputSize is assumed.
     * @param inputName     The name of the input node.
     * @param outputName    The name of the output node.
     * @throws IOException  If the label file cannot be read.
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowAudioClassifier c = new TensorFlowAudioClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        // The labels are used later to build the Recognition beans
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br =
                new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                c.labels.add(line);
            }
        } finally {
            // Close the reader even if reading fails
            br.close();
        }

        c.inferenceInterface = new TensorFlowInferenceInterface();
        if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
            throw new RuntimeException("TF initialization failed");
        }
        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.floatValues = new float[inputSize * inputSize * 1];
        c.outputs = new float[numClasses];

        return c;
    }

    // Recognition pipeline
    @Override
    public List<Recognition> recognizeAudio(String fileName) {
        // Log this method so that it can be analyzed with systrace.
        Trace.beginSection("recognizeAudio");

        Trace.beginSection("preprocessAudio");
        // Preprocess the audio data to normalized float based
        // on the provided parameters.
        // Convert the audio file into the model input. The input is a 1-D float array,
        // so the 2-D feature matrix is flattened row by row. RFFT.inputData() is not
        // shown in this article; it is assumed to return an inputSize x inputSize matrix.
        double[][] data = RFFT.inputData(fileName);
        for (int i = 0; i < data.length; ++i) {
            for (int j = 0; j < data[0].length; ++j) {
                floatValues[i * inputSize + j] = (float) data[i][j];
            }
        }
        Trace.endSection();

        // Copy the input data into TensorFlow, with shape [1, inputSize, inputSize, 1].
        Trace.beginSection("fillNodeFloat");
        inferenceInterface.fillNodeFloat(
                inputName, new int[]{1, inputSize, inputSize, 1}, floatValues);
        Trace.endSection();

        // Run the inference call.
        Trace.beginSection("runInference");
        inferenceInterface.runInference(outputNames);
        Trace.endSection();

        // Copy the output Tensor (the per-class confidences) back into the output array.
        Trace.beginSection("readNodeFloat");
        inferenceInterface.readNodeFloat(outputName, outputs);
        Trace.endSection();

        // Find the best classifications.
        // A PriorityQueue yields the top results; Recognition is the bean class
        // defined in the Classifier interface.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        // Build a Recognition bean: class index as id, label name, confidence
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i]));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        if (recognitionsSize == 0) {
            Trace.endSection(); // "recognizeAudio": close the section before the early return
            return null;
        }
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        Trace.endSection(); // "recognizeAudio"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        inferenceInterface.enableStatLogging(debug);
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}
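The pre-processing and post-processing inside recognizeAudio() are plain Java, so you can check them off-device. The following is a minimal sketch. It assumes a rectangular feature matrix like the one RFFT.inputData() is expected to return. The class names AudioPostprocessing and ScoredLabel are hypothetical, and ScoredLabel stands in for Classifier.Recognition. The sketch shows two steps:

- Row-major flattening: element (i, j) goes to index i * cols + j, which matches an input of shape [1, rows, cols, 1].
- Top-k selection: scores above the threshold go into a PriorityQueue with a reversed comparator, and the top results are polled from it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical stand-in for Classifier.Recognition: class index, label, confidence.
final class ScoredLabel {
    final String id;
    final String title;
    final float confidence;

    ScoredLabel(String id, String title, float confidence) {
        this.id = id;
        this.title = title;
        this.confidence = confidence;
    }
}

final class AudioPostprocessing {
    // Flattens a rows x cols matrix into a row-major 1-D array:
    // element (i, j) is written to index i * cols + j.
    static float[] flatten(double[][] features) {
        int rows = features.length;
        int cols = rows == 0 ? 0 : features[0].length;
        float[] out = new float[rows * cols];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                out[i * cols + j] = (float) features[i][j];
            }
        }
        return out;
    }

    // Keeps scores strictly above threshold and returns at most maxResults of them,
    // highest confidence first. Indices without a label are reported as "unknown".
    static List<ScoredLabel> topK(float[] scores, List<String> labels,
                                  float threshold, int maxResults) {
        // Reversed comparator: the highest confidence sits at the head of the queue.
        PriorityQueue<ScoredLabel> pq = new PriorityQueue<>(
                Math.max(1, scores.length),
                (lhs, rhs) -> Float.compare(rhs.confidence, lhs.confidence));
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] > threshold) {
                String title = i < labels.size() ? labels.get(i) : "unknown";
                pq.add(new ScoredLabel(String.valueOf(i), title, scores[i]));
            }
        }
        List<ScoredLabel> results = new ArrayList<>();
        while (!pq.isEmpty() && results.size() < maxResults) {
            results.add(pq.poll());
        }
        return results;
    }
}
```

Unlike recognizeAudio(), this sketch returns an empty list rather than null when no score passes the threshold, so callers do not need a null check.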
We have now built TensorFlowAudioClassifier for recognizing and classifying audio. Here is how to use the class:
public static TensorFlowAudioClassifier classifier;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "file:///android_asset/acoustic.pb";
private static final String LABEL_FILE =
        "file:///android_asset/eventLabel.txt";
private static final int INPUT_SIZE = 40;

// The code below runs inside an Activity method (e.g. onCreate()), since getAssets() is a Context method
try {
    // Create the classifier
    classifier = (TensorFlowAudioClassifier) TensorFlowAudioClassifier.create(
            getAssets(),
            MODEL_FILE,
            LABEL_FILE,
            INPUT_SIZE,
            INPUT_NAME,
            OUTPUT_NAME
    );
    // Classify the audio file at fileName; the result holds at most MAX_RESULTS
    // recognitions, or is null if no class exceeds THRESHOLD
    List<Classifier.Recognition> results = classifier.recognizeAudio(fileName);
} catch (IOException e) {
    e.printStackTrace();
}
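You can sketch the label file handling in create() the same way. This minimal sketch assumes that labels are stored one per line, in the order of the model's output indices. LabelFiles, toAssetName and readLabels are hypothetical names. The sketch differs from create() in two ways:

- It strips the file:///android_asset/ prefix with startsWith/substring. A plain asset name therefore passes through unchanged. With labelFilename.split(...)[1], the same input would throw an ArrayIndexOutOfBoundsException.
- It uses try-with-resources, so the reader is closed even if reading fails.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.List;

final class LabelFiles {
    // URI prefix the code above uses for files in the APK's assets/ directory.
    static final String ASSET_PREFIX = "file:///android_asset/";

    // "file:///android_asset/eventLabel.txt" -> "eventLabel.txt", the name AssetManager.open() expects.
    // A path without the prefix is returned unchanged instead of causing an exception.
    static String toAssetName(String path) {
        return path.startsWith(ASSET_PREFIX) ? path.substring(ASSET_PREFIX.length()) : path;
    }

    // Reads one label per line; line i is the label for output index i.
    static List<String> readLabels(Reader source) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }
}
```

On the device, the Reader passed to readLabels() would be new InputStreamReader(assetManager.open(LabelFiles.toAssetName(LABEL_FILE))).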