android studio 3.+使用tensorflow lite将pb文件应用于android端

目录

  • 一、模型训练、保存和导入
    • 1.1、模型训练
    • 1.2、模型保存
    • 1.3、模型导入
  • 二、移植到Android
    • 2.1、下载jar包和so库
    • 2.2、Android Studio配置
    • 2.3、模型调用
  • 三、GitHub代码
  • 四、特别致谢

参考链接: https://blog.csdn.net/guyuealian/article/details/79672257
tensorflow模型GitHub地址: https://github.com/ChaoflyLi/MnistAndroid
android工程: https://github.com/ChaoflyLi/MnistToAndroid
写博客,踩坑不易 麻烦给个“star”哈

一、模型训练、保存和导入

1.1、模型训练

首先,需要定义模型的输入层和输出层节点的名字(通过形参 'name’指定,名字可以随意,后面加载模型时,都是通过该name来传递数据的):

x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点:x_input
.
.
.
pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点:output

特别注意:一定要特别注意
这个地方的output_type=‘int32’,我当时没看参考博主的这个地方,然后总是出错,造成了闪退

TensorFlow默认类型是float32,但我们希望返回的是一个int型,因此需要指定output_type=‘int32’;但注意了,在Windows下测试使用int64和float64都是可以的,但在Android平台上只能使用int32和float32,并且对应Java的int和float类型。

鄙人觉得最终的结果以int类型在java中使用,所以这个地方应该也是int

鄙人觉得
python代码: (上面的代码)
pre_num=tf.argmax(y,1,output_type=‘int32’,name=“output”)
java代码: (下面的代码)
int[] outputs = new int[OUT_COL*OUT_ROW]
这两部分应该相互对应

/**
     *  利用训练好的TensoFlow模型预测结果
     * @param bitmap 输入被测试的bitmap图
     * @return 返回预测结果,int数组
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap,28, 28);//需要将图片缩放带28*28
        //将数据feed给tensorflow的输入节点
        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        //运行tensorflow
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);
        ///获取输出节点的输出信息
        int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
        inferenceInterface.fetch(outputName, outputs);
        return outputs;
    }

1.2、模型保存

模型保存为pb文件

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    f.write(output_graph_def.SerializeToString())

完整的python代码

# coding=utf-8
# 单隐层SoftMax Regression分类器:训练和保存模型模块
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.framework import graph_util

print('tensortflow:{0}'.format(tf.__version__))

mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)


x = tf.placeholder(tf.float32, [None, 784], name='x_input')  # 输入节点名:x_input
y_ = tf.placeholder(tf.float32, [None, 10], name='y_input')


dense1 = tf.layers.dense(inputs=x,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
dense2 = tf.layers.dense(inputs=dense1,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
logits = tf.layers.dense(inputs=dense2,
                         units=10,
                         activation=None,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss, name='logit')
y = tf.nn.softmax(logits, name='final_result')

# 定义损失函数和优化方法
with tf.name_scope('loss'):
    loss = -tf.reduce_sum(y_ * tf.log(y))
with tf.name_scope('train_step'):
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    print(train_step)
# 初始化
sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)
# 训练
for step in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

# 测试模型准确率
pre_num = tf.argmax(y, 1, output_type='int32', name="output")  # 输出节点名:output
correct_prediction = tf.equal(pre_num, tf.argmax(y_, 1, output_type='int32'))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
a = accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})
print('测试正确率:{0}'.format(a))

# 保存训练好的模型
# 形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:  # ’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
    f.write(output_graph_def.SerializeToString())
sess.close()

1.3、模型导入

在tensorflow中使用训练好的模型
在我的test.py文件中,给出了使用模型的代码

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
pb_file_path = './model/mnist.pb'

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(pb_file_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        pre_label = sess.graph.get_tensor_by_name("output:0")
        print(pre_label)
        for i in range(100):
            batch_xs, batch_ys = mnist.train.next_batch(50)
            a = sess.run(pre_label, feed_dict={'x_input:0': batch_xs})
            prediction = np.array(a)
            print(prediction)
            # print(prediction.argmax(axis=1))

二、移植到Android

2.1、下载jar包和so库

我们需要的是libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,其余就是在Android Studio配置的问题。点击这两个文件可以找到其位置。
也可以在网盘中下载
.jar的网盘链接
链接: https://pan.baidu.com/s/1rbXe75HPjFSiWCwJNrXg0g
提取码: 1nff
复制这段内容后打开百度网盘手机App,操作更方便哦
.so的网盘链接:
链接: https://pan.baidu.com/s/1BR9B8FL8XDML-2-sKUbNTQ
提取码: swtt
复制这段内容后打开百度网盘手机App,操作更方便哦

2.2、Android Studio配置

  1. 新建一个android项目
  2. 在android模式下,新建一个assets目录并把pb文件放入其中,如下图所示
    android studio 3.+使用tensorflow lite将pb文件应用于android端_第1张图片
  3. 将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下,如下图所示
    android studio 3.+使用tensorflow lite将pb文件应用于android端_第2张图片
    4.把libandroid_tensorflow_inference_java.jar添加到依赖项,左键点击文件,右键点击,出现一个大的界面,在最下面找到‘Adds as Library’
    android studio 3.+使用tensorflow lite将pb文件应用于android端_第3张图片
    5、app\build.gradle配置
    在defaultConfig中添加
   multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }

增加sourceSets

    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }

android studio 3.+使用tensorflow lite将pb文件应用于android端_第4张图片
**注意:**如果第4步,已经完成了则不需要这一步,若没有
在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar:

	implementation files('libs/libandroid_tensorflow_inference_java.jar')

android studio 3.+使用tensorflow lite将pb文件应用于android端_第5张图片
6、配置gradle.properties
添加代码:

android.useDeprecatedNdk=true

android studio 3.+使用tensorflow lite将pb文件应用于android端_第6张图片
OK了,build.gradle配置完成了,剩下的就是java编程的问题了。

2.3、模型调用

在需要调用TensoFlow的地方,加载so库“System.loadLibrary(“tensorflow_inference”);并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了

TensorFlowInferenceInterface.feed()//送入输入数据
TensorFlowInferenceInterface.run()//进行模型计算
TensorFlowInferenceInterface.fetch()//获取输出数据
package com.example.mnistandroid;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;


public class PredictionTF {
    private static final String TAG = "PredictionTF";
    //设置模型输入/输出节点的数据维度
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28*28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;
    //模型中输入变量的名称
    private static final String inputName = "x_input";
    //模型中输出变量的名称
    private static final String outputName = "output";

    TensorFlowInferenceInterface inferenceInterface;
    static {
        //加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so库加载成功");
    }

    PredictionTF(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface对象
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e(TAG,"TensoFlow模型文件加载成功");
    }

    /**
     *  利用训练好的TensoFlow模型预测结果
     * @param bitmap 输入被测试的bitmap图
     * @return 返回预测结果,int数组
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap,28, 28);//需要将图片缩放带28*28
        //将数据feed给tensorflow的输入节点
        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        //运行tensorflow
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);
        ///获取输出节点的输出信息
        int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
        inferenceInterface.fetch(outputName, outputs);
        return outputs;
    }

    /**
     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
     * @param bitmap 输入被测试的bitmap图片
     * @param rx 将图片缩放到指定的大小(列)->28
     * @param ry 将图片缩放到指定的大小(行)->28
     * @return   返回归一化后的一维float数组 ->28*28
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
        int height = bitmap.getHeight();
        int width = bitmap.getWidth();
        // 计算缩放比例
        float scaleWidth = ((float) rx) / width;
        float scaleHeight = ((float) ry) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
        Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());
        Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());
        height = bitmap.getHeight();
        width = bitmap.getWidth();
        float[] result = new float[height*width];
        int k = 0;
        //行优先
        for(int j = 0;j < height;j++){
            for (int i = 0;i < width;i++){
                int argb = bitmap.getPixel(i,j);
                int r = Color.red(argb);
                int g = Color.green(argb);
                int b = Color.blue(argb);
                int a = Color.alpha(argb);
                //由于是灰度图,所以r,g,b分量是相等的。
                assert(r==g && g==b);
//                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);
                result[k++] = r / 255.0f;
            }
        }
        return result;
    }
}
  • 简单说明一下:项目新建了一个PredictionTF类,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径;
  • 从资源文件中获取BitMap图片,并传入 getPredict(Bitmap bitmap)方法,该方法首先将BitMap图像缩放到2828的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将2828的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)方法的作用;
  • 然后将数据feed给tensorflow的输入节点,并运行(run)tensorflow,最后获取(fetch)输出节点的输出信息。
    activity_main布局文件:


    
    

三、GitHub代码

tensorflow模型GitHub地址:https://github.com/ChaoflyLi/MnistAndroid
android工程:https://github.com/ChaoflyLi/MnistToAndroid

四、特别致谢

参考链接:https://blog.csdn.net/guyuealian/article/details/79672257

你可能感兴趣的:(Python,tensorflow,Android)