This post works through an example of deploying YOLOv5 on Android.
It is based mainly on the official PyTorch demo: https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp
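The app's home screen is shown in the figure below:
Switch test image
Three (or any number of) images are specified directly in the code; tap the 测试图片 (Test image) button to cycle through them.
Select image
Tap 选择图片 (Select image) to pick a picture from the gallery or take a new photo with the camera.
Live video
Tap 实时视频 (Live video) to open the camera and show the detection results directly on the camera preview.
Switch model (a feature I added)
Tap 切换模型 (Switch model) to choose which model is used for detection.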
First, get the official demo running. Download the officially provided yolov5s.torchscript.ptl from:
https://pytorch-mobile-demo-apps.s3.us-east-2.amazonaws.com/yolov5s.torchscript.ptl
After downloading, put it in the assets folder.
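If you run the app as-is, picking an image from the gallery fails with:
Unable to decode stream: java.io.FileNotFoundException:/…/open failed: EACCES (Permission denied)
To fix it, add this attribute to the application tag in AndroidManifest.xml:
android:requestLegacyExternalStorage="true"
After that the app runs normally. The attribute opts the app out of scoped storage on Android 10 (API level 29), which is what allows the demo's BitmapFactory.decodeFile call to read the gallery image by its file path.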
Next, train your own model with YOLOv5 v6.0. Training itself is not covered here; see the earlier posts in this series.
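Then modify the export_torchscript function in export.py. The change amounts to three added lines that additionally save the model with the .torchscript.ptl suffix, the format loaded by the PyTorch Lite interpreter on Android.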
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    # optimize_for_mobile is torch.utils.mobile_optimizer.optimize_for_mobile
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        f = str(f)  # added
        fl = file.with_suffix('.torchscript.ptl')  # added: output path for the Lite-interpreter model

        ts = torch.jit.trace(model, im, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        (optimize_for_mobile(ts) if optimize else ts)._save_for_lite_interpreter(str(fl))  # added

        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')
Then run in a terminal:
python export.py --weights runs/train/exp/weights/best.pt --include torchscript
This produces the model file best.torchscript.ptl next to best.pt.
Next, add a model-switching feature and use the custom-trained model.
First, update the PyTorch dependencies in build.gradle:
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
Try to keep these versions in line with the PyTorch version you use for training and export; for example, if you use PyTorch 1.9.0, put 1.9.0 here.
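The version-matching rule of thumb can be made concrete with a small check. The sketch below is a hypothetical helper (it is not part of the demo or of PyTorch) that compares the major.minor part of the PyTorch version used for export, such as the string printed by torch.__version__, with the version written in the Gradle dependency, ignoring local build suffixes like +cu111. It only encodes the advice above; it does not guarantee that two versions are compatible.

```java
// Hypothetical helper: compares the major.minor components of two PyTorch version strings.
final class TorchVersionCheck {
    private TorchVersionCheck() {}

    // Returns "major.minor" for strings such as "1.9.0", "1.9.0+cu111" or "1.10.0a0".
    static String majorMinor(String version) {
        String base = version.split("\\+", 2)[0]; // drop a local suffix such as "+cu111"
        String[] parts = base.split("\\.");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Unexpected version string: " + version);
        }
        return parts[0] + "." + parts[1];
    }

    // True when the export-side torch version and the Android dependency share major.minor.
    static boolean sameMajorMinor(String exportTorch, String androidDependency) {
        return majorMinor(exportTorch).equals(majorMinor(androidDependency));
    }
}
```

For example, sameMajorMinor("1.9.0+cu111", "1.9.0") is true, while sameMajorMinor("1.10.0", "1.9.0") is false because "1.10" and "1.9" differ.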
Then, in PrePostProcessor.java, remove the private modifier from mOutputColumn so that MainActivity can set it.
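Why MainActivity sets mOutputColumn to num_class + 5: each output row of a YOLOv5 model holds four box values (center x, center y, width, height), one objectness score and one score per class, so a row has num_class + 5 columns (85 for the 80 COCO classes, 15 for a 10-class model). The sketch below decodes one row laid out that way. YoloRow is a hypothetical name; it illustrates the layout and is not the demo's PrePostProcessor code.

```java
// Hypothetical sketch: decoding YOLOv5 output rows that are numClasses + 5 floats wide.
final class YoloRow {
    final float cx, cy, w, h;  // box center and size, in model-input pixels
    final float objectness;    // column 4
    final int classIndex;      // index of the highest class score
    final float classScore;    // that highest class score

    private YoloRow(float cx, float cy, float w, float h,
                    float objectness, int classIndex, float classScore) {
        this.cx = cx;
        this.cy = cy;
        this.w = w;
        this.h = h;
        this.objectness = objectness;
        this.classIndex = classIndex;
        this.classScore = classScore;
    }

    // Columns per output row: 4 box values + 1 objectness score + numClasses class scores.
    static int outputColumns(int numClasses) {
        return numClasses + 5;
    }

    // Decodes row `row` from the flattened output array of the model.
    static YoloRow decode(float[] outputs, int row, int numClasses) {
        int cols = outputColumns(numClasses);
        int base = row * cols;
        if (base + cols > outputs.length) {
            throw new IndexOutOfBoundsException("row " + row + " is outside the output array");
        }
        // Argmax over the class-score columns, which start at offset 5.
        int best = 0;
        float bestScore = outputs[base + 5];
        for (int c = 1; c < numClasses; c++) {
            float s = outputs[base + 5 + c];
            if (s > bestScore) {
                bestScore = s;
                best = c;
            }
        }
        return new YoloRow(outputs[base], outputs[base + 1], outputs[base + 2], outputs[base + 3],
                outputs[base + 4], best, bestScore);
    }
}
```

If the column count is wrong (for example num_class instead of num_class + 5), every row after the first is read from the wrong offset, so boxes and labels become meaningless.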
Next, edit the layout: add the Switch model button to activity_main.xml and adjust the layout of the other buttons:
<Button
    android:id="@+id/select"
    android:layout_width="100dp"
    android:layout_height="wrap_content"
    android:layout_marginTop="32dp"
    android:textAllCaps="false"
    android:text="@string/select_model"
    app:layout_constraintEnd_toEndOf="parent"
    app:layout_constraintEnd_toStartOf="@+id/selectButton"
    app:layout_constraintHorizontal_bias="0.5"
    app:layout_constraintStart_toStartOf="parent"
    app:layout_constraintTop_toBottomOf="@+id/detectButton"
    android:background="@drawable/button_selector"/>

<Button
    android:id="@+id/testButton"
    android:layout_width="wrap_content"
    android:layout_height="wrap_content"
    android:layout_marginTop="180dp"
    android:textAllCaps="false"
    app:layout_constraintEnd_toEndOf="parent"
    app:layout_constraintHorizontal_bias="0.5"
    app:layout_constraintStart_toStartOf="parent"
    app:layout_constraintTop_toBottomOf="@+id/imageView"
    android:background="@drawable/button_selector"/>
Then modify MainActivity.java and add the following three fields:
private String model_name = "yolov5s.torchscript.ptl";  // model file in assets
private String model_class = "classes.txt";              // class-label file in assets
private int num_class = 80;                              // number of classes the model predicts
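These three fields always have to change together. As an alternative to the if/else in ShowChoise (this is not the author's code), they can be bundled into one small immutable descriptor per model, with the dialog index selecting an entry. ModelSpec and ModelRegistry are hypothetical names; the file names and class counts are the ones used in this post.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical descriptor bundling what MainActivity keeps in model_name, model_class and num_class.
final class ModelSpec {
    final String displayName; // entry shown in the selection dialog
    final String modelFile;   // .ptl file in assets/
    final String labelFile;   // class-label file in assets/, one class per line
    final int numClasses;

    ModelSpec(String displayName, String modelFile, String labelFile, int numClasses) {
        this.displayName = displayName;
        this.modelFile = modelFile;
        this.labelFile = labelFile;
        this.numClasses = numClasses;
    }

    // Columns per YOLOv5 output row: 4 box values + objectness + one score per class.
    int outputColumns() {
        return numClasses + 5;
    }
}

// Hypothetical registry: the list index plays the role of `which` in the dialog callback.
final class ModelRegistry {
    private ModelRegistry() {}

    static final List<ModelSpec> MODELS = Collections.unmodifiableList(Arrays.asList(
            new ModelSpec("YOLOv5s", "yolov5s.torchscript.ptl", "classes.txt", 80),
            new ModelSpec("王者荣耀模型", "mymodel.ptl", "classes_wzry.txt", 10))); // "Honor of Kings model"

    // Display names in list order, e.g. for AlertDialog.Builder.setItems(...).
    static String[] displayNames() {
        String[] names = new String[MODELS.size()];
        for (int i = 0; i < names.length; i++) {
            names[i] = MODELS.get(i).displayName;
        }
        return names;
    }
}
```

Inside onClick(dialog, which), ModelSpec spec = ModelRegistry.MODELS.get(which); would then replace the if/else, and adding a model becomes a one-line change to the list.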
Add the click handler for the Switch model button:
private void ShowChoise()
{
    AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
    // builder.setIcon(R.drawable.ic_launcher_foreground);
    builder.setTitle("选择一个模型");  // "Choose a model"
    // Entries shown in the list (the second one reads "Honor of Kings model")
    final String[] modelNames = {"YOLOv5s", "王者荣耀模型"};
    // Show the entries as a selectable list
    builder.setItems(modelNames, new DialogInterface.OnClickListener()
    {
        @Override
        public void onClick(DialogInterface dialog, int which)
        {
            // Toast: "Selected model: ..."
            Toast.makeText(MainActivity.this, "选择的模型为:" + modelNames[which], Toast.LENGTH_SHORT).show();
            if (which == 0) {
                model_name = "yolov5s.torchscript.ptl";
                model_class = "classes.txt";
                num_class = 80;
            }
            else {
                model_name = "mymodel.ptl";
                model_class = "classes_wzry.txt";
                num_class = 10;
            }
            // Reload the model and its class labels
            try {
                mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                List<String> classes = new ArrayList<>();
                try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                    String line;
                    while ((line = br.readLine()) != null) {
                        classes.add(line);
                    }
                }
                PrePostProcessor.mClasses = new String[classes.size()];
                // Each output row: 4 box values + 1 objectness score + num_class class scores
                PrePostProcessor.mOutputColumn = num_class + 5;
                classes.toArray(PrePostProcessor.mClasses);
            } catch (IOException e) {
                Log.e("Object Detection", "Error reading assets", e);
                finish();
            }
        }
    });
    builder.show();
}
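Add one if branch for each model in the list. model_class is the model's class-label file; create it yourself in the same format as classes.txt (one class name per line). num_class is the number of classes.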
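Because num_class and the label file are maintained by hand, a mismatch is easy to introduce and only shows up later as wrong or missing labels. The sketch below is a hypothetical helper that reads a label file like ShowChoise does (one class per line) but closes the reader, skips blank lines such as a trailing empty line (an assumption; the original code keeps every line), and fails fast when the count differs from the expected number of classes. On Android the Reader would be new InputStreamReader(getAssets().open(model_class)).

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: loads class labels and checks them against the expected class count.
final class LabelLoader {
    private LabelLoader() {}

    // Reads one class name per line, ignoring blank lines; throws if the count is wrong.
    static String[] load(Reader source, int expectedClasses) throws IOException {
        List<String> classes = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                String name = line.trim();
                if (!name.isEmpty()) {
                    classes.add(name);
                }
            }
        }
        if (classes.size() != expectedClasses) {
            // IOException, so the existing catch (IOException e) in MainActivity also handles it
            throw new IOException("Label file has " + classes.size()
                    + " entries, but num_class is " + expectedClasses);
        }
        return classes.toArray(new String[0]);
    }
}
```

Throwing IOException rather than a new exception type is a deliberate choice here: the existing catch block in MainActivity logs it and finishes the activity without any further changes.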
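Finally, copy the best.torchscript.ptl produced in the previous step into the assets folder, together with the label file classes_wzry.txt, and rename the model to mymodel.ptl by hand. The code loads the model by exactly that name, so without the rename you get a file-not-found error. Then run the app.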
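The file-not-found error comes from assetFilePath, which opens assets/<model_name> by its exact name. A hypothetical pre-check like the one below gives a clearer message. On Android the list of names could come from getAssets().list(""), which returns the entries at the root of assets/; here the names are passed in as a plain array so the sketch runs without Android.

```java
import java.util.Arrays;

// Hypothetical pre-check: verifies that a model or label file exists among the asset names.
final class AssetCheck {
    private AssetCheck() {}

    // Throws if `name` is not among `assetNames`, listing the names that were found instead.
    static void requireAsset(String[] assetNames, String name) {
        if (!Arrays.asList(assetNames).contains(name)) {
            throw new IllegalStateException("Missing asset '" + name + "'; found "
                    + Arrays.toString(assetNames));
        }
    }
}
```

With the un-renamed file, requireAsset(new String[]{"best.torchscript.ptl"}, "mymodel.ptl") fails with a message that shows both names side by side.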
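Besides the changes above, I also translated the UI into Chinese and slightly adjusted image loading. The complete source of the modified files is below: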
activity_main.xml
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context="org.pytorch.demo.objectdetection.MainActivity">

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:layout_marginTop="0dp"
        android:background="#FFFFFF"
        android:contentDescription="@string/image_view"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <org.pytorch.demo.objectdetection.ResultView
        android:id="@+id/resultView"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:layout_marginTop="0dp"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <Button
        android:id="@+id/detectButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:text="@string/detect"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.498"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView"
        android:background="@drawable/button_selector"/>

    <ProgressBar
        android:id="@+id/progressBar"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:visibility="invisible"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.498"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView" />

    <Button
        android:id="@+id/selectButton"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:text="@string/select"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/liveButton"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toEndOf="@+id/select"
        app:layout_constraintTop_toTopOf="@+id/select"
        android:background="@drawable/button_selector"/>

    <Button
        android:id="@+id/liveButton"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:text="@string/live"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toEndOf="@+id/selectButton"
        app:layout_constraintTop_toTopOf="@+id/selectButton"
        android:background="@drawable/button_selector"/>

    <Button
        android:id="@+id/select"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:layout_marginTop="32dp"
        android:textAllCaps="false"
        android:text="@string/select_model"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/selectButton"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/detectButton"
        android:background="@drawable/button_selector"/>

    <Button
        android:id="@+id/testButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="180dp"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView"
        android:background="@drawable/button_selector"/>

</androidx.constraintlayout.widget.ConstraintLayout>
MainActivity.java
// Copyright (c) 2020 Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

package org.pytorch.demo.objectdetection;

import androidx.appcompat.app.AlertDialog;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.Context;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;
import android.widget.Toast;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity implements Runnable {
    private int mImageIndex = 0;
    private String[] mTestImages = {"test1.png", "test2.jpg", "test3.png"};

    private ImageView mImageView;
    private ResultView mResultView;
    private Button mButtonDetect;
    private Button mButtonSelect;
    private ProgressBar mProgressBar;
    private Bitmap mBitmap = null;
    private Module mModule = null;
    private float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;

    // Currently selected model: asset file, class-label file and number of classes
    private String model_name = "yolov5s.torchscript.ptl";
    private String model_class = "classes.txt";
    private int num_class = 80;

    // Copies an asset to internal storage and returns its absolute path, because
    // LiteModuleLoader.load() takes a file path. An existing non-empty copy is reused,
    // so after replacing a model in assets, clear the app data (or uninstall) first.
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

    private void ShowChoise()
    {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
        // builder.setIcon(R.drawable.ic_launcher_foreground);
        builder.setTitle("选择一个模型");  // "Choose a model"
        // Entries shown in the list (the second one reads "Honor of Kings model")
        final String[] modelNames = {"YOLOv5s", "王者荣耀模型"};
        // Show the entries as a selectable list
        builder.setItems(modelNames, new DialogInterface.OnClickListener()
        {
            @Override
            public void onClick(DialogInterface dialog, int which)
            {
                // Toast: "Selected model: ..."
                Toast.makeText(MainActivity.this, "选择的模型为:" + modelNames[which], Toast.LENGTH_SHORT).show();
                if (which == 0) {
                    model_name = "yolov5s.torchscript.ptl";
                    model_class = "classes.txt";
                    num_class = 80;
                }
                else {
                    model_name = "mymodel.ptl";
                    model_class = "classes_wzry.txt";
                    num_class = 10;
                }
                // Reload the model and its class labels
                try {
                    mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                    List<String> classes = new ArrayList<>();
                    try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                        String line;
                        while ((line = br.readLine()) != null) {
                            classes.add(line);
                        }
                    }
                    PrePostProcessor.mClasses = new String[classes.size()];
                    // Each output row: 4 box values + 1 objectness score + num_class class scores
                    PrePostProcessor.mOutputColumn = num_class + 5;
                    classes.toArray(PrePostProcessor.mClasses);
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });
        builder.show();
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        // Request the storage and camera permissions in a single call: while one request
        // is still pending, Android does not show a second requestPermissions() dialog
        List<String> missing = new ArrayList<>();
        for (String permission : new String[]{Manifest.permission.READ_EXTERNAL_STORAGE, Manifest.permission.CAMERA}) {
            if (ContextCompat.checkSelfPermission(this, permission) != PackageManager.PERMISSION_GRANTED) {
                missing.add(permission);
            }
        }
        if (!missing.isEmpty()) {
            ActivityCompat.requestPermissions(this, missing.toArray(new String[0]), 1);
        }

        setContentView(R.layout.activity_main);

        try {
            mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }

        mImageView = findViewById(R.id.imageView);
        mImageView.setImageBitmap(mBitmap);
        mResultView = findViewById(R.id.resultView);
        mResultView.setVisibility(View.INVISIBLE);

        // "测试图片 i/n" = "Test image i/n"
        final Button buttonTest = findViewById(R.id.testButton);
        buttonTest.setText(String.format("测试图片 %d/%d", mImageIndex + 1, mTestImages.length));
        buttonTest.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);
                mImageIndex = (mImageIndex + 1) % mTestImages.length;
                buttonTest.setText(String.format("测试图片 %d/%d", mImageIndex + 1, mTestImages.length));

                try {
                    mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
                    mImageView.setImageBitmap(mBitmap);
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });

        final Button buttonSelect = findViewById(R.id.selectButton);
        buttonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);

                // Options: "Choose from gallery", "Take photo", "Cancel"
                final CharSequence[] options = { "从相册选择", "拍照", "取消" };
                AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
                builder.setTitle("新测试图片");  // "New test image"

                builder.setItems(options, new DialogInterface.OnClickListener() {
                    @Override
                    public void onClick(DialogInterface dialog, int item) {
                        if (options[item].equals("拍照")) {
                            Intent takePicture = new Intent(android.provider.MediaStore.ACTION_IMAGE_CAPTURE);
                            startActivityForResult(takePicture, 0);
                        }
                        else if (options[item].equals("从相册选择")) {
                            Intent pickPhoto = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.INTERNAL_CONTENT_URI);
                            startActivityForResult(pickPhoto, 1);
                        }
                        else if (options[item].equals("取消")) {
                            dialog.dismiss();
                        }
                    }
                });
                builder.show();
            }
        });

        final Button buttonLive = findViewById(R.id.liveButton);
        buttonLive.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                final Intent intent = new Intent(MainActivity.this, ObjectDetectionActivity.class);
                startActivity(intent);
            }
        });

        mButtonDetect = findViewById(R.id.detectButton);
        mProgressBar = (ProgressBar) findViewById(R.id.progressBar);
        mButtonDetect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mButtonDetect.setEnabled(false);
                mProgressBar.setVisibility(ProgressBar.VISIBLE);
                mButtonDetect.setText(getString(R.string.run_model));

                // Factors for mapping boxes from model-input pixels to bitmap pixels (mImgScale*),
                // from bitmap pixels to the ImageView (mIvScale*), plus the centering offsets (mStart*)
                mImgScaleX = (float)mBitmap.getWidth() / PrePostProcessor.mInputWidth;
                mImgScaleY = (float)mBitmap.getHeight() / PrePostProcessor.mInputHeight;

                mIvScaleX = (mBitmap.getWidth() > mBitmap.getHeight() ? (float)mImageView.getWidth() / mBitmap.getWidth() : (float)mImageView.getHeight() / mBitmap.getHeight());
                mIvScaleY = (mBitmap.getHeight() > mBitmap.getWidth() ? (float)mImageView.getHeight() / mBitmap.getHeight() : (float)mImageView.getWidth() / mBitmap.getWidth());

                mStartX = (mImageView.getWidth() - mIvScaleX * mBitmap.getWidth()) / 2;
                mStartY = (mImageView.getHeight() - mIvScaleY * mBitmap.getHeight()) / 2;

                // Inference runs in run() on a background thread
                Thread thread = new Thread(MainActivity.this);
                thread.start();
            }
        });

        // New: model-switching button
        mButtonSelect = findViewById(R.id.select);
        mButtonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                ShowChoise();
            }
        });

        // Load the default model and its class labels
        try {
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
            List<String> classes = new ArrayList<>();
            try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                String line;
                while ((line = br.readLine()) != null) {
                    classes.add(line);
                }
            }
            PrePostProcessor.mClasses = new String[classes.size()];
            // Must be num_class + 5 (4 box values + objectness + class scores), as in ShowChoise();
            // plain num_class would misalign every output row
            PrePostProcessor.mOutputColumn = num_class + 5;
            classes.toArray(PrePostProcessor.mClasses);
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode != RESULT_CANCELED) {
            switch (requestCode) {
                case 0:  // photo taken with the camera
                    if (resultCode == RESULT_OK && data != null) {
                        // Without EXTRA_OUTPUT the camera app returns only a small thumbnail in "data"
                        mBitmap = (Bitmap) data.getExtras().get("data");
                        Matrix matrix = new Matrix();
                        //matrix.postRotate(90.0f);
                        matrix.postRotate(0);  // image-loading tweak: no rotation
                        mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                        mImageView.setImageBitmap(mBitmap);
                    }
                    break;
                case 1:  // image picked from the gallery
                    if (resultCode == RESULT_OK && data != null) {
                        Uri selectedImage = data.getData();
                        String[] filePathColumn = {MediaStore.Images.Media.DATA};
                        if (selectedImage != null) {
                            Cursor cursor = getContentResolver().query(selectedImage,
                                    filePathColumn, null, null, null);
                            if (cursor != null) {
                                cursor.moveToFirst();
                                int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
                                String picturePath = cursor.getString(columnIndex);
                                // Reading by file path is what needs requestLegacyExternalStorage on Android 10
                                mBitmap = BitmapFactory.decodeFile(picturePath);
                                Matrix matrix = new Matrix();
                                //matrix.postRotate(90.0f);
                                matrix.postRotate(0);  // image-loading tweak: no rotation
                                mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                                mImageView.setImageBitmap(mBitmap);
                                cursor.close();
                            }
                        }
                    }
                    break;
            }
        }
    }

    @Override
    public void run() {
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
        IValue[] outputTuple = mModule.forward(IValue.from(inputTensor)).toTuple();
        final Tensor outputTensor = outputTuple[0].toTensor();
        final float[] outputs = outputTensor.getDataAsFloatArray();
        final ArrayList<Result> results = PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);

        runOnUiThread(() -> {
            mButtonDetect.setEnabled(true);
            mButtonDetect.setText(getString(R.string.detect));
            mProgressBar.setVisibility(ProgressBar.INVISIBLE);
            mResultView.setResults(results);
            mResultView.invalidate();
            mResultView.setVisibility(View.VISIBLE);
        });
    }
}
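In the MainActivity listing, the detect-button handler computes six factors before inference starts: mImgScaleX/Y map model-input pixels (PrePostProcessor.mInputWidth × mInputHeight) back to bitmap pixels, mIvScaleX/Y map bitmap pixels to the ImageView, and mStartX/Y are the offsets of the centered image inside the view. The plain-Java sketch below reproduces that arithmetic and applies it to one box. ViewMapping is a hypothetical name, and the order in which the factors are applied (scale to the bitmap, then to the view, then add the offset) is an assumption about what PrePostProcessor does with them.

```java
// Hypothetical sketch of the box-mapping arithmetic set up in MainActivity's detect handler.
final class ViewMapping {
    final float imgScaleX, imgScaleY, ivScaleX, ivScaleY, startX, startY;

    ViewMapping(int bmpW, int bmpH, int inputW, int inputH, int viewW, int viewH) {
        // Model-input pixels -> bitmap pixels
        imgScaleX = (float) bmpW / inputW;
        imgScaleY = (float) bmpH / inputH;
        // Bitmap pixels -> view pixels; same expressions as mIvScaleX / mIvScaleY
        ivScaleX = bmpW > bmpH ? (float) viewW / bmpW : (float) viewH / bmpH;
        ivScaleY = bmpH > bmpW ? (float) viewH / bmpH : (float) viewW / bmpW;
        // Offsets that center the scaled bitmap inside the view
        startX = (viewW - ivScaleX * bmpW) / 2;
        startY = (viewH - ivScaleY * bmpH) / 2;
    }

    // Maps a box given as center/size in model-input pixels to {left, top, right, bottom} in view pixels.
    float[] toView(float cx, float cy, float w, float h) {
        float left = imgScaleX * (cx - w / 2);
        float right = imgScaleX * (cx + w / 2);
        float top = imgScaleY * (cy - h / 2);
        float bottom = imgScaleY * (cy + h / 2);
        return new float[] {
            startX + ivScaleX * left, startY + ivScaleY * top,
            startX + ivScaleX * right, startY + ivScaleY * bottom
        };
    }
}
```

For a 1280×640 bitmap shown in a 1000×1000 view with a 640×640 model input, the scaled image is 1000×500 and starts 250 px from the top, so a box covering the whole model input maps to (0, 250) to (1000, 750).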
strings.xml
<resources>
    <string name="app_name">YOLOv5</string>
    <string name="image_view">Image View</string>
    <string name="detect">检测</string>                 <!-- Detect -->
    <string name="run_model">正在运行,请稍后</string>  <!-- Running, please wait -->
    <string name="restart">Restart</string>
    <string name="select">选择图片</string>             <!-- Select image -->
    <string name="live">实时视频</string>               <!-- Live video -->
    <string name="select_model">切换模型</string>       <!-- Switch model -->
</resources>
button_selector.xml
<selector xmlns:android="http://schemas.android.com/apk/res/android">
    <item android:state_pressed="true">
        <shape>
            <solid android:color="#64AFFA"/>
            <corners android:radius="10dp"/>
            <padding
                android:bottom="2dp"
                android:left="3dp"
                android:right="3dp"
                android:top="2dp">
            </padding>
        </shape>
    </item>
    <item android:state_pressed="false">
        <shape>
            <solid android:color="#99CCFF"/>
            <corners android:radius="10dp"/>
            <padding
                android:bottom="2dp"
                android:left="3dp"
                android:right="3dp"
                android:top="2dp">
            </padding>
        </shape>
    </item>
</selector>
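In my tests, the packaged APK came out at just over 1 GB, which shows how much size the PyTorch framework adds once it is included; slimming it down is left for future work. Real-time video detection also runs at a very low frame rate, essentially a slideshow, probably because the phone does not have enough compute; this also needs further optimization.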