Continuing from my June 15 blog post: I had originally planned to port my own trained YOLO V3 model to the phone, but several attempts failed. After converting my own model to a .pb file, creating the TensorFlow inference interface always failed; I suspect something went wrong when I saved the model. So I decided to first port the SSD model that runs in the official demo.
Without further ado, here is the final result displayed on my phone.
This picture was taken with the camera; the model was then called for detection, the results were drawn on the original image and shown in an ImageView, and finally the save button was used to save the picture.
In the early stages I mainly referred to parts 2 and 3 of http://www.voidcn.com/article/p-rbnqjtim-brt.html.
Put the trained .pb file into the Android project under app/src/main/assets. If the assets directory does not exist, right-click main -> New -> Directory and enter assets.
For this step I copied ssd_mobilenet_v1_android_export.pb and coco_labels_list.txt from the official TensorFlow demo into the assets folder.
Place the downloaded libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar under the libs folder in the structure shown below. I will upload these dependency files to my resources so they can be downloaded directly.
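For reference, the resulting project layout looks roughly like this (the armeabi-v7a subdirectory matches the abiFilters setting configured below; adjust it if you target other ABIs):

app/
    libs/
        armeabi-v7a/
            libtensorflow_inference.so
        libandroid_tensorflow_inference_java.jar
    src/main/assets/
        ssd_mobilenet_v1_android_export.pb
        coco_labels_list.txt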
Add the following to defaultConfig:
multiDexEnabled true
ndk {
    abiFilters "armeabi-v7a"
}
Add sourceSets:
sourceSets {
    main {
        jniLibs.srcDirs = ['libs']
    }
}
A screenshot of build.gradle after these additions is shown below:
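In text form, the two snippets sit inside the android { } block of app/build.gradle roughly like this (other existing entries such as compileSdkVersion and applicationId are omitted):

android {
    ...
    defaultConfig {
        ...
        multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }
    }
    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }
}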
Add the TensorFlow inference jar libandroid_tensorflow_inference_java.jar to dependencies:
implementation files('libs/libandroid_tensorflow_inference_java.jar')
At this point build.gradle is fully configured; next comes calling the model.
Creating a new interface works the same way as creating a new class, so I won't repeat that here.
This interface is copied directly from the official demo without any modification, so you can simply copy and paste it.
package com.example.mycamera;

import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

public interface Classifier {
    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /**
         * Display name for the recognition.
         */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private final Float confidence;

        /** Optional location within the source image for the location of the recognized object. */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }
            if (title != null) {
                resultString += title + " ";
            }
            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }
            if (location != null) {
                resultString += location + " ";
            }
            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}
The main functions of the detector class created here are: 1. create an interface to the model; 2. feed the image to be recognized into the model; 3. fetch the recognition results from the model and return the final result.
// Factory method that builds the detector implementing the Classifier interface.
public static Classifier create(
        final AssetManager assetManager,
        final String modelFilename,
        final String labelFilename,
        final int inputSize) throws IOException {
    final TFYoloV3Detector d = new TFYoloV3Detector();

    // Read the label list from assets (the path is passed with a file:///android_asset/ prefix).
    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    InputStream labelsInput = assetManager.open(actualFilename);
    BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
        //LOGGER.w(line);
        d.labels.add(line);
    }
    br.close();

    // Load the frozen graph and make sure the expected input/output nodes exist.
    d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
    final Graph g = d.inferenceInterface.graph();
    d.inputName = "image_tensor";
    final Operation inputOp = g.operation(d.inputName);
    if (inputOp == null) {
        throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
    }
    d.inputSize = inputSize;
    final Operation outputOp1 = g.operation("detection_scores");
    if (outputOp1 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_scores'");
    }
    final Operation outputOp2 = g.operation("detection_boxes");
    if (outputOp2 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_boxes'");
    }
    final Operation outputOp3 = g.operation("detection_classes");
    if (outputOp3 == null) {
        throw new RuntimeException("Failed to find output Node 'detection_classes'");
    }

    // Pre-allocate buffers.
    d.outputNames = new String[] {"detection_boxes", "detection_scores",
            "detection_classes", "num_detections"};
    d.intValues = new int[d.inputSize * d.inputSize];
    d.byteValues = new byte[d.inputSize * d.inputSize * 3];
    d.outputScores = new float[MAX_RESULTS];
    d.outputLocations = new float[MAX_RESULTS * 4];
    d.outputClasses = new float[MAX_RESULTS];
    d.outputNumDetections = new float[1];
    return d;
}
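For example, the detector can then be created in the Activity roughly like this (a sketch: the two file names are the assets copied earlier, YOLO_INPUT_SIZE is the input-size constant used later in this post, and the label path must keep the file:///android_asset/ prefix because create() splits on it):

private Classifier detector;

// e.g. in onCreate():
try {
    detector = TFYoloV3Detector.create(
            getAssets(),
            "file:///android_asset/ssd_mobilenet_v1_android_export.pb",
            "file:///android_asset/coco_labels_list.txt",
            YOLO_INPUT_SIZE); // must match the size the bitmap is resized to before recognizeImage()
} catch (final IOException e) {
    e.printStackTrace();
}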
If you implement this Classifier interface directly, the IDE may report errors; just click the red light bulb shown at the error location and choose to implement the interface's methods.
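For reference, the resulting detector class skeleton might look roughly like this (a sketch: the fields are the ones used by create() and recognizeImage() in this post, the MAX_RESULTS value is an assumption, and the three remaining interface methods simply delegate to the inference interface as in the official demo):

public class TFYoloV3Detector implements Classifier {
    private static final int MAX_RESULTS = 100; // maximum number of detections kept (value is an assumption)

    private String inputName;
    private int inputSize;
    private final Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private byte[] byteValues;
    private float[] outputLocations;
    private float[] outputScores;
    private float[] outputClasses;
    private float[] outputNumDetections;
    private String[] outputNames;
    private boolean logStats = false;
    private TensorFlowInferenceInterface inferenceInterface;

    // create() and recognizeImage() are shown in this post.

    @Override
    public void enableStatLogging(final boolean debug) {
        logStats = debug;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}

The recognizeImage() method, which feeds the image into the model and fetches the results, is as follows: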
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
    //Bitmap bitmapResized = bitmapToFloatArray(bitmap, inputSize, inputSize); // the image must already be resized to inputSize x inputSize
    // Copy the input data into TensorFlow.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    for (int i = 0; i < intValues.length; ++i) {
        byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
        byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
        byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
    }

    // Feed the image to be recognized into the model.
    inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
    inferenceInterface.run(outputNames, logStats); // run the model

    outputLocations = new float[MAX_RESULTS * 4];
    outputScores = new float[MAX_RESULTS];
    outputClasses = new float[MAX_RESULTS];
    outputNumDetections = new float[1];

    // Fetch the recognition results from the model.
    inferenceInterface.fetch(outputNames[0], outputLocations);
    inferenceInterface.fetch(outputNames[1], outputScores);
    inferenceInterface.fetch(outputNames[2], outputClasses);
    inferenceInterface.fetch(outputNames[3], outputNumDetections);

    // Scale them back to the input size. The boxes come back as normalized
    // [ymin, xmin, ymax, xmax] values in the range 0..1.
    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    for (int i = 0; i < (int) outputNumDetections[0]; ++i) {
        final RectF detection =
                new RectF(
                        outputLocations[4 * i + 1] * inputSize,
                        outputLocations[4 * i] * inputSize,
                        outputLocations[4 * i + 3] * inputSize,
                        outputLocations[4 * i + 2] * inputSize);
        recognitions.add(
                new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
    }
    /*final ArrayList recognitions = new ArrayList();
    for (int i = 0; i <= Math.min(pq.size(), MAX_RESULTS); ++i) {
        recognitions.add(pq.poll());
    }*/
    return recognitions;
}
Resize the image to the required size; here bitmap is the image you want to feed into the model:
Bitmap bitmapResized = bitmapResize(bitmap, YOLO_INPUT_SIZE, YOLO_INPUT_SIZE); // resize the image to 416*416
The bitmapResize function is as follows:
// Scale the original image to the model's input size; bitmap is the original image, rx and ry are the model's input width and height.
public static Bitmap bitmapResize(Bitmap bitmap, int rx, int ry) {
    int height = bitmap.getHeight();
    int width = bitmap.getWidth();
    // Compute the scale factors.
    float scaleWidth = ((float) rx) / width;
    float scaleHeight = ((float) ry) / height;
    Matrix matrix = new Matrix();
    matrix.postScale(scaleWidth, scaleHeight);
    bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    return bitmap;
}
Compute the scale ratio between the original image and the image fed to the model; scaleimageX and scaleimageY are floats:
scaleimageX = (float) (bitmap.getWidth() * 1.0) / bitmapResized.getWidth(); // scale ratio between the original image and the model input, x direction
scaleimageY = (float) (bitmap.getHeight() * 1.0) / bitmapResized.getHeight(); // scale ratio between the original image and the model input, y direction
final List<Classifier.Recognition> results = detector.recognizeImage(bitmapResized); // get the recognition results
First set up the canvas and paint parameters, then map the model's detections back onto the original image, and finally draw the bounding box, class name, and confidence for each detection. The code is as follows:
croppedBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true); // copy the original image
final Canvas canvas = new Canvas(croppedBitmap); // create a canvas backed by the copy
final Paint paint = new Paint(); // paint for the boxes
paint.setColor(Color.RED); // set the color
paint.setStyle(Paint.Style.STROKE); // draw outlines only
paint.setStrokeWidth(5.0f); // set the stroke width
final Paint paintText = new Paint(); // paint for the text
paintText.setColor(Color.RED); // set the color
paintText.setTextSize(80); // set the text size

float minimumConfidence = MINIMUM_CONFIDENCE_YOLO;
final List<Classifier.Recognition> mappedRecognitions =
        new LinkedList<Classifier.Recognition>();
for (final Classifier.Recognition result : results) {
    // Map the bounding box back to its position in the original image.
    final RectF location = new RectF(
            result.getLocation().left * scaleimageX,
            result.getLocation().top * scaleimageY,
            result.getLocation().right * scaleimageX,
            result.getLocation().bottom * scaleimageY);
    // Only detections above the confidence threshold are drawn on the original image.
    if (location != null && result.getConfidence() >= minimumConfidence) {
        canvas.drawRect(location, paint); // draw the bounding box
        canvas.drawText(result.getTitle() + " " + result.getConfidence(),
                location.left,
                location.top - 10,
                paintText); // draw the class name and confidence above the box
        //cropToFrameTransform.mapRect(location);
        result.setLocation(location);
        mappedRecognitions.add(result);
    }
}
cameraPicture.setImageBitmap(croppedBitmap); // show the final image in the ImageView
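The save button mentioned at the beginning only has to write croppedBitmap out as a JPEG. A minimal sketch (the file location is just an example; depending on the Android version, storage permissions or MediaStore may be needed for other locations):

// Write the annotated bitmap to app-specific external storage (path is illustrative only).
File file = new File(getExternalFilesDir(null), "detection_result.jpg");
try (FileOutputStream out = new FileOutputStream(file)) {
    croppedBitmap.compress(Bitmap.CompressFormat.JPEG, 95, out);
} catch (IOException e) {
    e.printStackTrace();
}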
That concludes the steps for calling the model. I mainly read through the Android demo in the TensorFlow repository, extracted the functionality I needed, and in the end it worked.
It took six days, working on this every waking hour apart from eating and sleeping. I started with no idea where to begin and ended up with the model-calling functionality I needed. The biggest gain is that I now know the steps for calling a deep learning model on Android, which means the second-to-last chapter of my graduation project is essentially done. The next step is to find the correct way to save the model and convert it to a .pb file, so that I can port my own trained grain-insect detection model to the phone.
My own takeaway: I need to calm down and read the source code instead of getting impatient and always wanting to copy something ready-made.
One step closer to the final goal. Keep going!