Android端使用TensorFlow进行图像分类

      最近一直在看TensorFlow的视频教程,它是Google的一个机器学习的跨平台开源框架,可以移植到Android、ios等移动端设备运行。GitHub上面有许多关于TensorFlow开发demo,这让极客开发者们心情躁动,跃跃欲试。今天主要与大家探讨下运用TensorFlow进行图片分类。

      它实现的图片分类效果还是相当准确的,值得点赞,先上一张宠物猫的识别:

Android端使用TensorFlow进行图像分类_第1张图片

      TensorFlow移植到Android端非常方便,gradle导入依赖,结合训练好的样本库和标签关键词(它俩存放于assets目录),把图片Bitmap转成ByteBuffer传进去,调用解析器Interprerter的run方法,就可以获取结果。它的label标签是这样定义(其中有一个tiger cat就是上图识别结果):

Android端使用TensorFlow进行图像分类_第2张图片

       gradle依赖:

compile 'org.tensorflow:tensorflow-lite:0.1.7'

       首先初始化图像分类器:

    /**
     * 初始化TensorFlow图像分类器
     * @param context 上下文
     * @param inputImageWidth 图像宽度
     * @param inputImageHeight 图像高度
     */
    public TensorFlowImageClassifier(Context context, int inputImageWidth, int inputImageHeight)
            throws IOException {
        this.tfLite = new Interpreter(TensorFlowHelper.loadModelFile(context, MODEL_FILE));
        this.labels = TensorFlowHelper.readLabels(context, LABELS_FILE);

        imgData = ByteBuffer.allocateDirect(
                        DIM_BATCH_SIZE * inputImageWidth * inputImageHeight * DIM_PIXEL_SIZE);
        imgData.order(ByteOrder.nativeOrder());
        confidencePerLabel = new byte[1][labels.size()];

        // Pre-allocate buffer for image pixels.
        intValues = new int[inputImageWidth * inputImageHeight];
    }

       其中加载模型文件和标签文件过程如下:

    /**
     * 加载模型文件
     * @param context 上下文
     * @param modelFile 模型文件
     * @return 映射的字节缓存
     */
    public static MappedByteBuffer loadModelFile(Context context, String modelFile)
            throws IOException {

        AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelFile);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    /**
     * 加载标签文件
     * @param context 上下文
     * @param labelsFile 标签文件
     * @return 标签列表
     */
    public static List readLabels(Context context, String labelsFile) {
        AssetManager assetManager = context.getAssets();
        ArrayList result = new ArrayList<>();
        try (InputStream is = assetManager.open(labelsFile);
             BufferedReader br = new BufferedReader(new InputStreamReader(is))) {
            String line;
            while ((line = br.readLine()) != null) {
                result.add(line);
            }
            return result;
        } catch (IOException ex) {
            throw new IllegalStateException("Cannot read labels from " + labelsFile);
        }
    }

       图像Bitmap需要转换成ByteBuffer:

   /**
     * Bitmap转成ByteBuffer
     */
    public static void convertBitmapToByteBuffer(Bitmap bitmap, int[] intValues, ByteBuffer imgData) {
        if (imgData == null) {
            return;
        }
        imgData.rewind();
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0,
                bitmap.getWidth(), bitmap.getHeight());
        // Encode the image pixels into a byte buffer representation matching the expected
        // input of the Tensorflow model
        int pixel = 0;
        for (int i = 0; i < bitmap.getWidth(); ++i) {
            for (int j = 0; j < bitmap.getHeight(); ++j) {
                final int val = intValues[pixel++];
                imgData.put((byte) ((val >> 16) & 0xFF));
                imgData.put((byte) ((val >> 8) & 0xFF));
                imgData.put((byte) (val & 0xFF));
            }
        }
    }

       令人兴奋的时刻来了,调用解析器的run方法进行图像识别、获取分类结果:

    /**
     * 执行图像识别,返回分类结果集
     * @param image Bitmap
     */
    public Collection doRecognize(Bitmap image) {
        TensorFlowHelper.convertBitmapToByteBuffer(image, intValues, imgData);

        long startTime = SystemClock.uptimeMillis();
        // Here's where the magic happens!!!
        tfLite.run(imgData, confidencePerLabel);
        long endTime = SystemClock.uptimeMillis();
        Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));

        // Get the results with the highest confidence and map them to their labels
        return TensorFlowHelper.getBestResults(confidencePerLabel, labels);
    }

       获取最佳结果集过程就是,通过优先级队列对结果的置信度从高到底进行排序:

    /**
     * 获取最佳的结果集
      */
    public static Collection getBestResults(byte[][] labelProbArray,
                                                         List labelList) {
        PriorityQueue sortedLabels = new PriorityQueue<>(RESULTS_TO_SHOW,
                new Comparator() {
                    @Override
                    public int compare(Recognition lhs, Recognition rhs) {
                        return Float.compare(lhs.getConfidence(), rhs.getConfidence());
                    }
                });


        for (int i = 0; i < labelList.size(); ++i) {
            Recognition r = new Recognition( String.valueOf(i),
                    labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f);
            sortedLabels.add(r);
            if (r.getConfidence() > 0) {
                Log.d("ImageRecognition", r.toString());
            }
            if (sortedLabels.size() > RESULTS_TO_SHOW) {
                sortedLabels.poll();
            }
        }

        List results = new ArrayList<>(RESULTS_TO_SHOW);
        for (Recognition r: sortedLabels) {
            results.add(0, r);
        }

        return results;
    }

       其他的图像分类结果如下:

Android端使用TensorFlow进行图像分类_第3张图片

Android端使用TensorFlow进行图像分类_第4张图片

Android端使用TensorFlow进行图像分类_第5张图片

       最后,推荐大家学习的视频教程,是一位Google中国的机器学习开发者进行演讲:点击打开链接。另外,还有TensorFlow中文社区:http://www.tensorfly.cn/。Google提供的图像分类demo:https://github.com/androidthings/sample-tensorflow-imageclassifier。本篇文章demo基于Google提供基础上修改:图像分类demo

你可能感兴趣的:(android开发,机器学习)