深度学习 - TensorFlow Lite模型,云侧训练与安卓端侧推理

TensorFlow Lite模型,云侧训练与安卓端侧推理

  • 引言
  • 一、云侧深度模型的训练代码
    • 1.加载数据集的格式分析
      • 1.1 从数据集加载的数据格式
      • 1.2 对加载的数据进行处理
    • 2. 深度模型搭建
    • 3. 模型训练、评估、保存、转换
    • 4. 模型预测
  • 二、端侧安卓的推理代码
    • 1. 安卓项目配置
      • 1.1 app.gradle引入依赖
      • 1.2 AndroidManifest.xml新增照相机权限
      • 1.3 模型放置
    • 2. 安卓端侧代码实现
      • 2.1 布局文件
      • 2.2 主函数文件
      • 2.3 mnist数据集工具类
  • 三、测试结果
  • 参考网址
  • 总结

引言

本次博客主要基于TensorFlow官网的demo进行学习,把学习过程的心得理解记录。其主要内容为TensorFlow云侧训练深度模型,并转换为手机端lite深度模型,最后在安卓手机端侧利用该模型进行推理得出预测结果。本次学习以mnist数据集为例,毕竟入手深度学习,mnist相当于学习编程语言的Hello World!利用的工具有Anaconda的Jupyter Notebook,和Android Studio。

一、云侧深度模型的训练代码

1.加载数据集的格式分析

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()

        # MNIST中的圖片預設為uint8(0-255的數字)。以下程式碼將其正規化到0-1之間的浮點數,並在最後增加一維作為顏色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

导入TensorFlow和numpy包即可,我们会用到TensorFlow的Keras,它是用 Python 编写的高级神经网络 API,支持快速的构建网络框架。

1.1 从数据集加载的数据格式

先对MNISTLoader这个类进行分析,该类先加载了数据集数据,如下。

(train_data, train_label), (test_data, test_label) = mnist.load_data()

打印数据格式如下。

print("train_data:变量类型={0},变量形状={1},数据类型={2}".format(type(train_data), train_data.shape, train_data.dtype))
print("train_label:变量类型={0},变量形状={1},数据类型={2}".format(type(train_label), train_label.shape,train_label.dtype))
print("test_data:变量类型={0},变量形状={1},数据类型={2}".format(type(test_data), test_data.shape,test_data.dtype))
print("test_label:变量类型={0},变量形状={1},数据类型={2}".format(type(test_label), test_label.shape,test_label.dtype))

打印结果如下。

train_data:变量类型=<class 'numpy.ndarray'>,变量形状=(60000, 28, 28),数据类型=uint8
train_label:变量类型=<class 'numpy.ndarray'>,变量形状=(60000,),数据类型=uint8
test_data:变量类型=<class 'numpy.ndarray'>,变量形状=(10000, 28, 28),数据类型=uint8
test_label:变量类型=<class 'numpy.ndarray'>,变量形状=(10000,),数据类型=uint8

也就是说加载了60000张28×28的图片作为训练集,10000张28×28的图片作为测试集。其中的数据类型为uint8,取值为0~255。
接着又用了np.expand_dims()为图片的数据集进行了维度扩展,axis=-1表示在原来的变量形状的最后一个维度增加多一维,-1在python的索引通常都是表示最后一个索引。为什么要增加这么个维度呢?因为最后一个维度的数值表示图片的通道数。比如图片为RGB图时,最后一个维度的数值是3,而mnist的数据集为灰度图片,即单通道表示的图片,所以最后一个维度数值是1。train_label、test_label的数据则是用0~9表示对应数据集的各个类。

1.2 对加载的数据进行处理

对加载的数据进行的运算,主要包括对图片进行0~1数值的归一化,维度扩展,和数据类型转换;对标签值进行数值类型转换。注意对数值类型转换尤为重要,这跟后续在安卓端编程中需要用到什么数据类型来作为输入输出要对应起来。数据转换的语句如下。

train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)
train_label = train_label.astype(np.int32)

再次运行如下语句查看数据格式

print("train_data:变量类型={0},变量形状={1},数据类型={2}".format(type(train_data), train_data.shape, train_data.dtype))
print("train_label:变量类型={0},变量形状={1},数据类型={2}".format(type(train_label), train_label.shape, train_label.dtype))

得到了新的数据格式,作为最终输入到模型进行训练的数据格式

train_data:变量类型=<class 'numpy.ndarray'>,变量形状=(60000, 28, 28, 1),数据类型=float32
train_label:变量类型=<class 'numpy.ndarray'>,变量形状=(60000,),数据类型=int32

2. 深度模型搭建

用Keras的Sequential来按顺序搭建模型,超级简单。需要添加的神经网络层,只需要add进来就可以了,Keras提供了很多常用的网络层。同时目前最新版本的Keras搭建模型时,每一层(包括首层输入层)的输入会根据上一层的输出自动推断,所以不需要input_shape参数。

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(
    filters=32,             # 卷积滤波器数量
    kernel_size=[5, 5],     # 卷积核大小
    padding="same",         # padding策略
    activation=tf.nn.relu   # 激活函数
))
model.add(tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2))
model.add(tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=[5, 5],
    padding="same",
    activation=tf.nn.relu
))
model.add(tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2))
model.add(tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,)))
model.add(tf.keras.layers.Dense(units=1024, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(units=10, activation=tf.nn.softmax))

深度学习 - TensorFlow Lite模型,云侧训练与安卓端侧推理_第1张图片

3. 模型训练、评估、保存、转换

num_epochs = 20
batch_size = 50
learning_rate = 0.001
save_path = r"D:\code\jupyter\saved"

# 数据加载器
data_loader = MNISTLoader()

# 模型编译
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)

# 模型训练
model.fit(data_loader.train_data, data_loader.train_label,
          epochs=num_epochs, batch_size=batch_size)

# 模型评估
print(model.evaluate(data_loader.test_data, data_loader.test_label))

# 模型保存
model.save(save_path)

# 模型转换
converter = tf.lite.TFLiteConverter.from_saved_model(save_path)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
open(os.path.join(save_path, "mnist_savedmodel_quantized.tflite"),
     "wb").write(tflite_quant_model)

模型的损失函数采用了sparse_categorical_crossentropy,则不同类的label直接用数字表示就可以了,如数字2的图片对应的label值为2。

模型训练时会动态给出结果如下:

1200/1200 [==============================] - 42s 35ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9924

模型评估时会动态给出结果如下:

313/313 [==============================] - 2s 6ms/step - loss: 0.0375 - sparse_categorical_accuracy: 0.9881

最后模型mnist_savedmodel_quantized.tflite保存到了相应的路径save_path,同时,利用转换器转换为适合安卓手机端使用的量化模型。

4. 模型预测

# 找测试数据第一张图片来看看,展示的时候shape是28*28
im = data_loader.test_data[0].reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
plt.close()

深度学习 - TensorFlow Lite模型,云侧训练与安卓端侧推理_第2张图片
预测图片如下:

im = im.reshape(1, 28, 28, 1)
print("各个类的概率:{0}".format(model.predict(im)))
print("最大概率的类:{0}".format(model.predict_classes(im))) 

关于模型的输入格式,由于我们在构建model的时候,首层Conv2D没有使用data_format参数,其默认输入格式为channels_last,即batch_shape + (spatial_dim1, spatial_dim2, spatial_dim3, channels)。所以reshape的第一个数字是batch_size,最后一个数字是颜色通道数。

输出结果如下:

各个类的概率:[[9.9865129e-09 4.3024698e-08 5.2642001e-05 3.9080669e-06 2.2962024e-10
  2.2086294e-07 5.7997704e-13 9.9992096e-01 2.2194282e-08 2.2103426e-05]]
最大概率的类:[7]

通过上面的例子可知,我们直接预测的输出是一个包含各个类的预测概率的数组,而通过model.predict_classes(im)则会拿到预测数组里分值最高的数值对应的索引,model.predict_classes() 该方法将会被抛弃,提示使用np.argmax(model.predict(x), axis=-1)

二、端侧安卓的推理代码

安卓端实现通过调用相机获取图片输入,接着通过模型推理后打印日志输出结果。

1. 安卓项目配置

1.1 app.gradle引入依赖

android {
    aaptOptions {
        noCompress "tflite" // 编译apk时,不压缩tflite文件
    }
}
dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.4.0' // 推理工具
    implementation 'org.tensorflow:tensorflow-lite-support:0.2.0' // 用于读取加载模型
}

1.2 AndroidManifest.xml新增照相机权限

<uses-permission android:name="android.permission.CAMERA" />

1.3 模型放置

把转换后的模型mnist_savedmodel_quantized.tflite放置到src\main\assets目录下,没该目录的需新建一个。

2. 安卓端侧代码实现

2.1 布局文件


<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity"
    android:orientation="vertical"
    android:gravity="center">

    <ImageView
        android:id="@+id/camera_image"
        android:layout_weight="1"
        android:layout_width="wrap_content"
        android:layout_height="0dp">
    ImageView>
    <Button
        android:id="@+id/open_camera_button"
        android:text="打开相机"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content">
    Button>

LinearLayout>

2.2 主函数文件

package com.example.tensorflowlite;

import java.io.IOException;

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;

import androidx.annotation.Nullable;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;

/**
 * 主活动页,通过点击底部打开相机按钮,拍照后返回主页,在主页显示照片图像
 * 同时日志打印推理的结果
 */
public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    private static final String TAG = "MainActivity";

    private static final String MODEL_PATH = "mnist_savedmodel_quantized.tflite";

    private static final int CAMERA_PERMISSION_REQ_CODE = 1;

    private static final int CAMERA_CAPTURE_REQ_CODE = 2;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Button button = findViewById(R.id.open_camera_button);
        button.setOnClickListener(this);
    }

    /**
     * 打开照相机
     */
    private void openCamera() {
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            // 无权限,引导用户授予权限
            if (ActivityCompat.shouldShowRequestPermissionRationale(this, Manifest.permission.CAMERA)) {
                // 提示已经禁止
                Log.e(TAG, "error");
            } else {
                // 请求相机权限
                ActivityCompat.requestPermissions(this, new String[] {Manifest.permission.CAMERA},
                    CAMERA_PERMISSION_REQ_CODE);
            }
        } else {
            // 有权限,直接打开相机,并等待回调
            Intent camera = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
            startActivityForResult(camera, CAMERA_CAPTURE_REQ_CODE);
        }
    }

    @Override
    public void onClick(View v) {
        switch (v.getId()) {
            case R.id.open_camera_button:
                openCamera();
                break;
            default:
                Log.i(TAG, "nothing");
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && requestCode == CAMERA_CAPTURE_REQ_CODE) {
            Bundle extras = data.getExtras();

            // 拿到的数据变得很小,被压缩过了,对于mnist数据集来说,够够的了
            Bitmap bitmap = (Bitmap) extras.get("data");

            // 画出拿到的数据
            ImageView cameraImage = findViewById(R.id.camera_image);
            cameraImage.setImageBitmap(bitmap);

            // 推理
            inference(bitmap);
        }
    }

    /**
     * 对图像进行推理
     */
    private void inference(Bitmap bitmap) {
        try {
            // 加载模型后的解释器
            Interpreter interpreter =
                new Interpreter(FileUtil.loadMappedFile(this, MODEL_PATH), new Interpreter.Options());

            // 新建变量,用于存放推理输出结果
            float[][] labelProbArray = new float[1][10];

            // 开始推断
            interpreter.run(MnistUtil.convertBitmapToByteBuffer(bitmap), labelProbArray);

            // 打印推断结果,顺序按
            for (int i = 0; i < labelProbArray[0].length; i++) {
                Log.i(TAG, labelProbArray[0][i] + "");
            }

        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

在主活动页中,通过点击底部打开相机按钮,拍照后返回主页,在主页显示照片图像同时日志打印推理的结果。主要的函数有:openCamera()打开相机,onActivityResult(int requestCode, int resultCode, @Nullable Intent data)等待相机回调结果获取图片,inference(Bitmap bitmap)对图像进行推理,同时显示图像和打印推理结果。

2.3 mnist数据集工具类

package com.example.tensorflowlite;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import android.graphics.Bitmap;

/**
 * mnist数据集工具
 */
public class MnistUtil {
    public static ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        // 定义图像的宽高
        int dimImgWidth = 28;
        int dimImgHeight = 28;

        // 推理时,一次只推理一张图像
        int dimBatchSize = 1;

        // 相当于云侧训练时用np.expand_dims多扩展出来的一维
        int dimPixelSize = 1;

        // 一个float等于4个字节
        int numBytesPerChannel = 4;

        // 存放图像数据的数组
        int[] intValues = new int[dimImgWidth * dimImgHeight];

        // 缩放图像至 28 * 28
        Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, dimImgWidth, dimImgHeight, true);

        // 复制缩放后的bitmap到存放图像数据的数组
        scaleBitmap.getPixels(intValues, 0, scaleBitmap.getWidth(), 0, 0, scaleBitmap.getWidth(),
            scaleBitmap.getHeight());

        // 创建图像数据缓冲区
        ByteBuffer imgData =
            ByteBuffer.allocateDirect(numBytesPerChannel * dimBatchSize * dimImgWidth * dimImgHeight * dimPixelSize);

        // ByteBuffer的字节序设置为当前硬件平台的字节序
        imgData.order(ByteOrder.nativeOrder());

        // 把position设为0,limit不变,一般在把数据重写入Buffer前调用。
        imgData.rewind();

        // 处理图像数据,归一化为0~1的浮点型数据,并把存放图像数据的数组里的数组往缓冲器拷贝
        int pixel = 0;
        for (int i = 0; i < dimImgWidth; ++i) {
            for (int j = 0; j < dimImgHeight; ++j) {
                int val = intValues[pixel++];

                // 添加把Pixel数值转化并添加到ByteBuffer
                addImgValue(imgData, val);
            }
        }
        return imgData;
    }

    /**
     * 添加图像数据值。对图像数据进行处理,归一化至0~1.0的浮点数据
     *
     * @param imgData 缓冲区数据
     * @param val 整形数据
     */
    private static void addImgValue(ByteBuffer imgData, int val) {
        int mImageMean = 0;
        float mImageStd = 255.0f;
        imgData.putFloat(((val & 0xFF) - mImageMean) / mImageStd);
    }
}

注意这里的图像缓冲区大小为什么要乘以4:ByteBuffer.allocateDirect(numBytesPerChannel * dimBatchSize * dimImgWidth * dimImgHeight * dimPixelSize)创建了一个4×1×28×28×1大小的缓冲区存储图片,因为缓冲区是以字节byte来存储的,通过计算,每个图像像素点最终转化为float型,而float型在java虚拟机中以4个字节存在,所以需要乘以4。在图像比较大的时候,缓冲区是很重要的。

三、测试结果

深度学习 - TensorFlow Lite模型,云侧训练与安卓端侧推理_第3张图片
日志打印所有类的概率如下:

2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.0030371095
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.003125498
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.011447249
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.055658735
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 7.467345E-5
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.05097304
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 1.911169E-5
2021-07-08 10:13:03.325 15543-15543/com.example.tensorflowlite I/MainActivity: 0.8677362
2021-07-08 10:13:03.326 15543-15543/com.example.tensorflowlite I/MainActivity: 9.3077944E-4
2021-07-08 10:13:03.326 15543-15543/com.example.tensorflowlite I/MainActivity: 0.006997687

结果为0~9按顺序打印后,可以看到数字7的概率为0.8677362。

参考网址

官方安卓端侧代码
官方云侧训练模型代码
Keras中文文档
TensorFlow Lite中文文档

总结

你学会了吗?

你可能感兴趣的:(笔记)