最近一直在看TensorFlow的视频教程,它是Google的一个机器学习的跨平台开源框架,可以移植到Android、ios等移动端设备运行。GitHub上面有许多关于TensorFlow开发demo,这让极客开发者们心情躁动,跃跃欲试。今天主要与大家探讨下运用TensorFlow进行图片分类。
它实现的图片分类效果还是相当准确的,值得点赞,先上一张宠物猫的识别:
TensorFlow移植到Android端非常方便,gradle导入依赖,结合训练好的样本库和标签关键词(它俩存放于assets目录),把图片Bitmap转成ByteBuffer传进去,调用解析器Interprerter的run方法,就可以获取结果。它的label标签是这样定义(其中有一个tiger cat就是上图识别结果):
gradle依赖:
compile 'org.tensorflow:tensorflow-lite:0.1.7'
首先初始化图像分类器:
/**
* 初始化TensorFlow图像分类器
* @param context 上下文
* @param inputImageWidth 图像宽度
* @param inputImageHeight 图像高度
*/
public TensorFlowImageClassifier(Context context, int inputImageWidth, int inputImageHeight)
throws IOException {
this.tfLite = new Interpreter(TensorFlowHelper.loadModelFile(context, MODEL_FILE));
this.labels = TensorFlowHelper.readLabels(context, LABELS_FILE);
imgData = ByteBuffer.allocateDirect(
DIM_BATCH_SIZE * inputImageWidth * inputImageHeight * DIM_PIXEL_SIZE);
imgData.order(ByteOrder.nativeOrder());
confidencePerLabel = new byte[1][labels.size()];
// Pre-allocate buffer for image pixels.
intValues = new int[inputImageWidth * inputImageHeight];
}
其中加载模型文件和标签文件过程如下:
/**
* 加载模型文件
* @param context 上下文
* @param modelFile 模型文件
* @return 映射的字节缓存
*/
public static MappedByteBuffer loadModelFile(Context context, String modelFile)
throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelFile);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
/**
* 加载标签文件
* @param context 上下文
* @param labelsFile 标签文件
* @return 标签列表
*/
public static List readLabels(Context context, String labelsFile) {
AssetManager assetManager = context.getAssets();
ArrayList result = new ArrayList<>();
try (InputStream is = assetManager.open(labelsFile);
BufferedReader br = new BufferedReader(new InputStreamReader(is))) {
String line;
while ((line = br.readLine()) != null) {
result.add(line);
}
return result;
} catch (IOException ex) {
throw new IllegalStateException("Cannot read labels from " + labelsFile);
}
}
图像Bitmap需要转换成ByteBuffer:
/**
* Bitmap转成ByteBuffer
*/
public static void convertBitmapToByteBuffer(Bitmap bitmap, int[] intValues, ByteBuffer imgData) {
if (imgData == null) {
return;
}
imgData.rewind();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0,
bitmap.getWidth(), bitmap.getHeight());
// Encode the image pixels into a byte buffer representation matching the expected
// input of the Tensorflow model
int pixel = 0;
for (int i = 0; i < bitmap.getWidth(); ++i) {
for (int j = 0; j < bitmap.getHeight(); ++j) {
final int val = intValues[pixel++];
imgData.put((byte) ((val >> 16) & 0xFF));
imgData.put((byte) ((val >> 8) & 0xFF));
imgData.put((byte) (val & 0xFF));
}
}
}
令人兴奋的时刻来了,调用解析器的run方法进行图像识别、获取分类结果:
/**
* 执行图像识别,返回分类结果集
* @param image Bitmap
*/
public Collection doRecognize(Bitmap image) {
TensorFlowHelper.convertBitmapToByteBuffer(image, intValues, imgData);
long startTime = SystemClock.uptimeMillis();
// Here's where the magic happens!!!
tfLite.run(imgData, confidencePerLabel);
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
// Get the results with the highest confidence and map them to their labels
return TensorFlowHelper.getBestResults(confidencePerLabel, labels);
}
获取最佳结果集过程就是,通过优先级队列对结果的置信度从高到底进行排序:
/**
* 获取最佳的结果集
*/
public static Collection getBestResults(byte[][] labelProbArray,
List labelList) {
PriorityQueue sortedLabels = new PriorityQueue<>(RESULTS_TO_SHOW,
new Comparator() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
return Float.compare(lhs.getConfidence(), rhs.getConfidence());
}
});
for (int i = 0; i < labelList.size(); ++i) {
Recognition r = new Recognition( String.valueOf(i),
labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f);
sortedLabels.add(r);
if (r.getConfidence() > 0) {
Log.d("ImageRecognition", r.toString());
}
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
}
List results = new ArrayList<>(RESULTS_TO_SHOW);
for (Recognition r: sortedLabels) {
results.add(0, r);
}
return results;
}
其他的图像分类结果如下:
最后,推荐大家学习的视频教程,是一位Google中国的机器学习开发者进行演讲:点击打开链接。另外,还有TensorFlow中文社区:http://www.tensorfly.cn/。Google提供的图像分类demo:https://github.com/androidthings/sample-tensorflow-imageclassifier。本篇文章demo基于Google提供基础上修改:图像分类demo