基于现有 TensorFlow 模型构建 Android 应用

在之前写的一篇文章 TensorFlow,从一个 Android Demo 开始 中通过编译官方的 Demo 接触到了 TensorFlow 实际使用场景。这篇文章打算从一个Android 开发者的角度切入,看看构建一个基于 TensorFlow 的 Android 应用的完整流程。

相关代码可查看:GitHub 项目地址

通过 TensorFlow 用已有模型构建 Android 应用

在 Google 的 TensorFlow examples project 中,有一个 Sample 叫作 TF Classify,它通过使用 Google Inception 模型对实时的相机图像帧进行分类,并显示展示当前图像的分类推断结果。

基于现有 TensorFlow 模型构建 Android 应用_第1张图片
TF Classify

下面我们就基于这个现有模型,在 Android 平台上实现一个可以对物品进行分类的图像识别应用。

获取数据模型

这里可以直接下载 Google 提供的一个数据模型 inception5h.zip ,其中 .pb 后缀的文件是已经训练好的模型,而 .txt 对应的是训练数据包含的所有标签。

基于现有 TensorFlow 模型构建 Android 应用_第2张图片

这个模型可对 1008 种物品识别分类,具体有哪些类可以查看标签信息,至于每个类别到底训练了多少张图片就不得而知了。

基于现有 TensorFlow 模型构建 Android 应用_第3张图片

在 Android 项目中引入 TensorFlow

跟在项目中集成其他第三库一样,先在 build.gradle 中添加对 TensorFlow 的依赖。

compile 'org.tensorflow:tensorflow-android:1.6.0'

这里我们直接使用了 Google 为我们编译好的 TensorFlow 现成库了,如果你想自行对 TensorFlow 进行 NDK 交叉编译得到库文件也可以。

图像识别功能的实现

复制模型文件到项目 assets 文件夹:
如下图所示,我们在项目 assets 文件夹下创建一个 model 文件夹,并把之前下载的 inception5h.zip 解压后的全部文件复制到该文件夹下。

基于现有 TensorFlow 模型构建 Android 应用_第4张图片

添加模型调用的相关类
因为我们要实现的功能和官方 demo 相似,只是训练的有所模型不同。既然对模型的使用方式是一样的,那这里就直接使用 Google demo 项目中提供的 Classifier.java 和 TensorFlowImageClassifier.java 这两个类来实现。

我们可以先跳过这部分内容的具体实现,等到对整体流程有个大致认识后再回过头来消化掉,这样可以更好地去理解。

这里我们重点关注下面两个方法,一个是 TensorFlowImageClassifier 的静态方法 create 方法:

   /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean The assumed mean of the image values.
     * @param imageStd The assumed std of the image values.
     * @param inputName The label of the image input node.
     * @param outputName The label of the output node.
     * @throws IOException
     */
    public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename,
            int inputSize, int imageMean, float imageStd, String inputName, String outputName) 

该方法需要传入模型相关的参数进行初始化,完成后返回一个 Classifier 实例。

通过 Classifier 对象,我们可以调用其 recognizeImage 方法来识别我们传入的 bitmap 图像数据,该方法会返回图像类别后对物品类别进行推断的标签结果:

/**
 * 进行图片识别
 */
 public List recognizeImage(final Bitmap bitmap) 

相关主要功能代码的实现:
相关代码可查看:GitHub 项目地址

public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    ...
    
    // 模型相关配置
    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";
    private static final String MODEL_FILE = "file:///android_asset/model/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";

    private Executor executor;
    private Uri currentTakePhotoUri;

    private TextView result;
    private ImageView ivPicture;
    private Classifier classifier;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        if (!isTaskRoot()) {
            finish();
        }

        setContentView(R.layout.activity_main);

        findViewById(R.id.iv_choose_picture).setOnClickListener(this);
        findViewById(R.id.iv_take_photo).setOnClickListener(this);

        ivPicture = findViewById(R.id.iv_picture);
        result = findViewById(R.id.tv_classifier_info);

        // 避免耗时任务占用 CPU 时间片造成UI绘制卡顿,提升启动页面加载速度
        Looper.myQueue().addIdleHandler(idleHandler);

    }

    /**
     *  主线程消息队列空闲时(视图第一帧绘制完成时)处理耗时事件
     */
    MessageQueue.IdleHandler idleHandler = new MessageQueue.IdleHandler() {
        @Override
        public boolean queueIdle() {
            // 初始化 Classifier
            if (classifier == null) {
                // 创建 TensorFlowImageClassifier
               classifier = TensorFlowImageClassifier.create(MainActivity.this.getAssets(),
                       MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);
            }

            // 初始化线程池
            executor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
                @Override
                public Thread newThread(@NonNull Runnable r) {
                    Thread thread = new Thread(r);
                    thread.setDaemon(true);
                    thread.setName("ThreadPool-ImageClassifier");
                    return thread;
                }
            });
            // 请求权限
            requestMultiplePermissions();
            // 返回 false 时只会回调一次
            return false;
        }
    };

    @Override
    public void onClick(View view) {
        switch (view.getId()) {
            case R.id.iv_choose_picture :
                choosePicture();
                break;
            case R.id.iv_take_photo :
                takePhoto();
                break;
            default:break;
        }
    }

    /**
     * 选择一张图片并裁剪获得一个小图
     */
    private void choosePicture() {
        Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
        intent.setType("image/*");
        startActivityForResult(intent, PICTURE_REQUEST_CODE);
    }

    /**
     * 使用系统相机拍照
     */
    private void takePhoto() {
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, CAMERA_PERMISSIONS_REQUEST_CODE);
        } else {
            openSystemCamera();
        }
    }

    /**
     * 打开系统相机
     */
    private void openSystemCamera() {
        //调用系统相机
        Intent takePhotoIntent = new Intent();
        takePhotoIntent.setAction(MediaStore.ACTION_IMAGE_CAPTURE);

        //这句作用是如果没有相机则该应用不会闪退,要是不加这句则当系统没有相机应用的时候该应用会闪退
        if (takePhotoIntent.resolveActivity(getPackageManager()) == null) {
            Toast.makeText(this, "当前系统没有可用的相机应用", Toast.LENGTH_SHORT).show();
            return;
        }

        String fileName = "TF_" + System.currentTimeMillis() + ".jpg";
        File photoFile = new File(FileUtil.getPhotoCacheFolder(), fileName);

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            //通过FileProvider创建一个content类型的Uri
            currentTakePhotoUri = FileProvider.getUriForFile(this, "gdut.bsx.tensorflowtraining.fileprovider", photoFile);
            //对目标应用临时授权该 Uri 所代表的文件
            takePhotoIntent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
        } else {
            currentTakePhotoUri = Uri.fromFile(photoFile);
        }

        //将拍照结果保存至 outputFile 的Uri中,不保留在相册中
        takePhotoIntent.putExtra(MediaStore.EXTRA_OUTPUT, currentTakePhotoUri);
        startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
    }

    /**
     * 处理图片
     * @param imageUri
     */
    private void handleInputPhoto(Uri imageUri) {
        // 加载图片
        GlideApp.with(MainActivity.this).asBitmap().listener(new RequestListener() {

            @Override
            public boolean onLoadFailed(@Nullable GlideException e, Object model, Target target, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onLoadFailed");
                Toast.makeText(MainActivity.this, "图片加载失败", Toast.LENGTH_SHORT).show();
                return false;
            }

            @Override
            public boolean onResourceReady(Bitmap resource, Object model, Target target, DataSource dataSource, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onResourceReady");
                startImageClassifier(resource);
                return false;
            }
        }).load(imageUri).into(ivPicture);

        result.setText("Processing...");
    }

    /**
     * 开始图片识别匹配
     * @param bitmap
     */
    private void startImageClassifier(final Bitmap bitmap) {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    Log.i(TAG, Thread.currentThread().getName() + " startImageClassifier");
                    Bitmap croppedBitmap = getScaleBitmap(bitmap, INPUT_SIZE);

                    final List results = classifier.recognizeImage(croppedBitmap);
                    Log.i(TAG, "startImageClassifier results: " + results);
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            result.setText(String.format("results: %s", results));
                        }
                    });
                } catch (IOException e) {
                    Log.e(TAG, "startImageClassifier getScaleBitmap " + e.getMessage());
                }
            }
        });
    }

   /**
     * 请求相机和外部存储权限
     */
    private void requestMultiplePermissions() {

        String storagePermission = Manifest.permission.WRITE_EXTERNAL_STORAGE;
        String cameraPermission = Manifest.permission.CAMERA;

        int hasStoragePermission = ActivityCompat.checkSelfPermission(this, storagePermission);
        int hasCameraPermission = ActivityCompat.checkSelfPermission(this, cameraPermission);

        List permissions = new ArrayList<>();
        if (hasStoragePermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(storagePermission);
        }

        if (hasCameraPermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(cameraPermission);
        }
        
        if (!permissions.isEmpty()) {
            String[] params = permissions.toArray(new String[permissions.size()]);
            ActivityCompat.requestPermissions(this, params, PERMISSIONS_REQUEST);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (resultCode == RESULT_OK) {
            if (requestCode == PICTURE_REQUEST_CODE) {
                // 处理选择的图片
                handleInputPhoto(data.getData());
            } else if (requestCode == OPEN_SETTING_REQUEST_COED){
                requestMultiplePermissions();
            } else if (requestCode == TAKE_PHOTO_REQUEST_CODE) {
                // 如果拍照成功,加载图片并识别
                handleInputPhoto(currentTakePhotoUri);
            }
        }
    }

    /**
     * 对图片进行缩放
     * @param bitmap
     * @param size
     * @return
     * @throws IOException
     */
    private static Bitmap getScaleBitmap(Bitmap bitmap, int size) throws IOException {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float scaleWidth = ((float) size) / width;
        float scaleHeight = ((float) size) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    }
}

运行效果

图片选择和拍照获取界面:

基于现有 TensorFlow 模型构建 Android 应用_第5张图片

物品识别结果展示界面:

基于现有 TensorFlow 模型构建 Android 应用_第6张图片

相关代码可查看:GitHub 项目地址

是不是觉得通过 TensorFlow 在现有的数据模型基础下,我们可以很简单就完成了一个简单的图像识别应用。

在使用这个模型来推断物品类型的过程中,发现好像有时候准确率不是那么高,这时候改怎么办。如果说只是想识别一些特定种类的物品,哪有又该怎么办?

在之前一篇文章中我有提到过,机器学习是依靠对大量有标签的样本数据进行反复训练后才逐步得到的最佳模型。对未知无标签样本的推断依赖这个模型的准确程度。所以我们可以通过对现有模型进行迁移训练(retrain)来定制我们自己的模型。

下面就通过对现有的 Google Inception-V3 模型进行 retrain ,对 5 种花朵样本数据的进行训练,来完成一个可以识别五种花朵的模型。

具体实现方式可以参考我的另外一篇文章:通过迁移训练来定制 TensorFlow 模型

你可能感兴趣的:(基于现有 TensorFlow 模型构建 Android 应用)