This post shows how to run YOLOv5 inference from Java with DJL. Along the way it also covers some of the OpenCV Java API.
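Deep Java Library (DJL) is a deep learning toolkit from Amazon. It lets Java developers use today's mainstream deep learning frameworks from Java, including PyTorch, TensorFlow, MXNet and PaddlePaddle (yes, even PaddlePaddle), as well as models in ONNX format.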
This demo uses the pretrained yolov5s model directly. The YOLOv5 project ships a complete model export script, and release 5.0 is much more polished than earlier versions.
The YOLOv5 export script is models/export.py.
Some settings need to be adjusted before exporting.
To use PyTorch with DJL, add the following dependencies:
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>xyz.hyhy</groupId>
    <artifactId>TestAI</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.11.0</djl.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.8.1</version>
        </dependency>
    </dependencies>
</project>
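For OpenCV, the Windows download is an .exe file that is really just an archive. Extract it, go into the build folder, and copy the jar file together with the dll from the x64 or x86 folder into the project's lib folder. Pick x64 or x86 according to whether your system is 64-bit or 32-bit.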
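The static block in the Main class below calls System.loadLibrary(Core.NATIVE_LIBRARY_NAME), which only succeeds if the copied dll sits in a directory listed in the java.library.path system property (for example by starting the JVM with -Djava.library.path=lib). The following is a minimal, stdlib-only sketch for checking this before running the demo. It roughly mirrors how System.loadLibrary turns a library name into a file name (System.mapLibraryName) and then searches that path. The class name NativeLibCheck and the default name "opencv_java452" are hypothetical; use the value of Core.NATIVE_LIBRARY_NAME for your OpenCV version.

```java
import java.io.File;
import java.nio.file.Files;
import java.nio.file.InvalidPathException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;

/** Looks for a native library file roughly the way System.loadLibrary resolves it. */
final class NativeLibCheck {

    /** Returns the first directory entry of searchPath that contains libName's native file. */
    static Optional<Path> findLibrary(String libName, String searchPath) {
        // e.g. "opencv_java452" -> "opencv_java452.dll" on Windows, "libopencv_java452.so" on Linux
        String fileName = System.mapLibraryName(libName);
        for (String dir : searchPath.split(File.pathSeparator)) {
            if (dir.isEmpty()) {
                continue;
            }
            try {
                Path candidate = Paths.get(dir, fileName);
                if (Files.isRegularFile(candidate)) {
                    return Optional.of(candidate);
                }
            } catch (InvalidPathException e) {
                // skip malformed entries in the search path
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        // Hypothetical default name; in the real project use Core.NATIVE_LIBRARY_NAME.
        String libName = args.length > 0 ? args[0] : "opencv_java452";
        String searchPath = System.getProperty("java.library.path", "");
        Optional<Path> found = findLibrary(libName, searchPath);
        System.out.println(found.isPresent()
                ? "found: " + found.get()
                : System.mapLibraryName(libName) + " not in java.library.path: " + searchPath);
    }
}
```

If the check reports the file as missing, the usual fix is to point java.library.path at the lib folder (or at the x64/x86 folder directly) in the run configuration.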
Put the exported yolov5s.torchscript.pt file into the resources/yolov5s folder. You also need to write a coco.names file in the same folder that lists the class names, one per line:
coco.names
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
pottedplant
bed
diningtable
toilet
tvmonitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
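As far as I know, DJL maps the class index predicted by the model to the corresponding line of this file (index 0 is the first line). This means a missing, extra, or blank line in the middle would shift every later label. The yolov5s model pretrained on COCO predicts 80 classes, so the file should have exactly 80 lines. Below is a small stdlib-only sanity check. The class name CocoNames and the default path (a standard Maven layout) are my own assumptions.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/** Sanity check for the class-names file passed to optSynsetArtifactName. */
final class CocoNames {

    /** Reads the file line by line; line i is assumed to be the label for class index i. */
    static List<String> load(Path file) throws IOException {
        return Files.readAllLines(file, StandardCharsets.UTF_8);
    }

    /** Returns the indices of blank lines, which would become empty labels. */
    static List<Integer> blankLineIndices(List<String> names) {
        List<Integer> blanks = new ArrayList<>();
        for (int i = 0; i < names.size(); i++) {
            if (names.get(i).trim().isEmpty()) {
                blanks.add(i);
            }
        }
        return blanks;
    }

    public static void main(String[] args) throws IOException {
        // Assumed location (Maven layout); pass another path as the first argument if needed.
        Path file = Paths.get(args.length > 0 ? args[0] : "src/main/resources/yolov5s/coco.names");
        List<String> names = load(file);
        System.out.println("classes: " + names.size() + " (COCO-pretrained yolov5s uses 80)");
        System.out.println("blank lines at: " + blankLineIndices(names));
        if (!names.isEmpty()) {
            System.out.println("index 0 -> " + names.get(0) + ", last -> " + names.get(names.size() - 1));
        }
    }
}
```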
package xyz.hyhy;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.DetectedObjects.DetectedObject;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.YoloV5Translator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgproc.Imgproc;
import org.opencv.videoio.VideoCapture;
import xyz.hyhy.utils.MyUtils;

import java.io.IOException;

import static org.opencv.videoio.Videoio.CAP_ANY;

public class Main {

    static {
        // load the OpenCV native library (the dll copied into lib)
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    public static void main(String[] args) {
        // YOLOv5 translator; class names are read from coco.names
        Translator<Image, DetectedObjects> translator =
                YoloV5Translator.builder().optSynsetArtifactName("coco.names").build();
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optDevice(Device.cpu())
                        .optModelUrls(Main.class.getResource("/yolov5s").getPath())
                        .optModelName("yolov5s.torchscript.pt")
                        .optTranslator(translator)
                        .optEngine("PyTorch")
                        .build();
        // ONNX variant:
        // Criteria<Image, DetectedObjects> criteria =
        //         Criteria.builder()
        //                 .setTypes(Image.class, DetectedObjects.class)
        //                 .optDevice(Device.cpu())
        //                 .optModelUrls(Main.class.getResource("/yolov5").getPath())
        //                 .optModelName("yolov5s.onnx")
        //                 .optTranslator(translator)
        //                 .optEngine("OnnxRuntime")
        //                 .build();
        try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
            VideoCapture cap = new VideoCapture(CAP_ANY);
            if (!cap.isOpened()) { // isOpened() reports whether the camera was opened successfully
                System.out.println("Camera Error"); // the camera could not be opened
            } else {
                Mat frame = new Mat(); // buffer for the current frame
                boolean flag = cap.read(frame); // read() grabs the current frame from the camera
                while (flag) {
                    detect(frame, model);
                    HighGui.imshow("yolov5", frame);
                    HighGui.waitKey(20);
                    flag = cap.read(frame);
                }
                cap.release(); // free the camera once frames stop arriving
            }
        } catch (RuntimeException | ModelException | TranslateException | IOException e) {
            e.printStackTrace();
        }
    }

    // reused between frames instead of allocating new objects for every box
    static Rect rect = new Rect();
    static Scalar color = new Scalar(0, 255, 0); // green

    static void detect(Mat frame, ZooModel<Image, DetectedObjects> model)
            throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        Image img = MyUtils.mat2Image(frame);
        long startTime = System.currentTimeMillis();
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects results = predictor.predict(img);
            // System.out.println(results);
            for (DetectedObject obj : results.<DetectedObject>items()) {
                BoundingBox bbox = obj.getBoundingBox();
                Rectangle rectangle = bbox.getBounds();
                String showText = String.format("%s: %.2f", obj.getClassName(), obj.getProbability());
                rect.x = (int) rectangle.getX();
                rect.y = (int) rectangle.getY();
                rect.width = (int) rectangle.getWidth();
                rect.height = (int) rectangle.getHeight();
                // draw the bounding box
                Imgproc.rectangle(frame, rect, color, 2);
                // draw the class name and probability at the box's top-left corner
                Imgproc.putText(frame, showText,
                        new Point(rect.x, rect.y),
                        Imgproc.FONT_HERSHEY_COMPLEX,
                        rectangle.getWidth() / 200, // font scale grows with box width
                        color);
            }
        }
        // frames per second, estimated from this frame's processing time
        System.out.println(String.format("%.2f", 1000.0 / (System.currentTimeMillis() - startTime)));
    }
}
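The FPS printed in detect is 1000.0 divided by a single frame's elapsed milliseconds. It therefore jumps from frame to frame, and it would print Infinity if a frame ever took less than one millisecond tick. One option is to time each frame with System.nanoTime() and average over the last N frames. The sketch below is stdlib-only; the class name FpsMeter and the window size are my own illustrative choices, not part of DJL or OpenCV.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Sliding-window FPS estimate from per-frame latencies measured with System.nanoTime(). */
final class FpsMeter {
    private final int window;
    private final Deque<Long> latencies = new ArrayDeque<>();
    private long sumNanos = 0L;

    FpsMeter(int window) {
        if (window <= 0) {
            throw new IllegalArgumentException("window must be positive");
        }
        this.window = window;
    }

    /** Records one frame's latency and returns the FPS averaged over the last `window` frames. */
    double record(long latencyNanos) {
        latencies.addLast(latencyNanos);
        sumNanos += latencyNanos;
        if (latencies.size() > window) {
            sumNanos -= latencies.removeFirst(); // drop the oldest frame from the window
        }
        // frames / seconds; returns 0 instead of Infinity if no time has elapsed
        return sumNanos > 0 ? 1e9 * latencies.size() / sumNanos : 0.0;
    }

    public static void main(String[] args) throws InterruptedException {
        FpsMeter meter = new FpsMeter(30);
        for (int i = 0; i < 5; i++) {
            long t0 = System.nanoTime();
            Thread.sleep(20); // stand-in for detect(frame, model)
            System.out.printf("%.2f FPS%n", meter.record(System.nanoTime() - t0));
        }
    }
}
```

In the demo, a single FpsMeter instance would be created once in main and its record call would replace the println at the end of detect.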
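When the program starts it will appear to hang for a while. Don't panic: DJL is downloading the PyTorch native libraries into the %USERPROFILE%\.djl.ai\pytorch directory. To confirm that the download is in progress, watch your network traffic monitor or check whether files are appearing in that folder.
The downloaded files are the libtorch dynamic libraries. DJL should pick the right build for your system automatically (as far as I can tell).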
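You can also check the download from code by polling the size of that folder. Below is a stdlib-only sketch. The class name DjlCacheSize and the 3-second interval are illustrative choices. It assumes that user.home resolves to %USERPROFILE% on Windows and that DJL uses the default location described above.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

/** Reports how many bytes have been written under a directory, e.g. DJL's PyTorch download folder. */
final class DjlCacheSize {

    /** Sums the sizes of all regular files under dir; returns 0 if dir does not exist yet. */
    static long totalBytes(Path dir) throws IOException {
        if (!Files.isDirectory(dir)) {
            return 0L;
        }
        try (Stream<Path> paths = Files.walk(dir)) {
            // File.length() avoids a checked exception in the lambda and returns 0 for vanished files
            return paths.filter(Files::isRegularFile)
                    .mapToLong(p -> p.toFile().length())
                    .sum();
        }
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        // Assumed default location: user.home/.djl.ai/pytorch
        Path dir = Paths.get(System.getProperty("user.home"), ".djl.ai", "pytorch");
        for (int i = 0; i < 10; i++) {
            System.out.printf("%s: %.1f MB%n", dir, totalBytes(dir) / 1e6);
            Thread.sleep(3000); // poll every 3 s while the demo starts up in another process
        }
    }
}
```

A steadily growing number means the libraries are still downloading. Once the number stops changing, the download has most likely finished.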
Result:
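Later I also tested the ONNX export of yolov5s, using the commented-out OnnxRuntime Criteria in the code above. That variant also needs DJL's OnnxRuntime engine on the classpath, which the pom above does not include. In my test, ONNX inference was faster: roughly three times the speed of TorchScript.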
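When reproducing a comparison like this, note that the time printed per frame in detect also includes creating a new Predictor for every frame. In addition, the first inferences typically carry one-time costs such as JIT compilation and native initialization. A fairer measurement runs a few untimed warm-up iterations and then reports the median of several timed runs. The stdlib-only sketch below does this. The class name LatencyBenchmark and the warm-up and run counts are illustrative assumptions. For the real comparison, the Runnable would wrap predictor.predict(img) for each model, rethrowing the checked TranslateException as a RuntimeException.

```java
import java.util.Arrays;

/** Warm-up-then-median latency measurement for comparing two inference setups. */
final class LatencyBenchmark {

    /** Runs task `warmup` times untimed, then `runs` times timed; returns the median in ms. */
    static double medianMillis(Runnable task, int warmup, int runs) {
        if (runs <= 0) {
            throw new IllegalArgumentException("runs must be positive");
        }
        for (int i = 0; i < warmup; i++) {
            task.run(); // untimed: absorbs one-time startup costs
        }
        long[] nanos = new long[runs];
        for (int i = 0; i < runs; i++) {
            long t0 = System.nanoTime();
            task.run();
            nanos[i] = System.nanoTime() - t0;
        }
        Arrays.sort(nanos);
        double median = (runs % 2 == 1)
                ? nanos[runs / 2]
                : (nanos[runs / 2 - 1] + nanos[runs / 2]) / 2.0;
        return median / 1e6;
    }

    public static void main(String[] args) {
        // Stand-in workload; replace with a call to predictor.predict(img) for each engine.
        double[] sink = new double[1];
        Runnable work = () -> {
            for (int k = 0; k < 1_000_000; k++) {
                sink[0] += Math.sqrt(k);
            }
        };
        System.out.printf("median: %.3f ms%n", medianMillis(work, 5, 20));
    }
}
```

The median is used instead of the mean so that occasional slow runs (for example a GC pause) distort the result less.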
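The MyUtils.mat2Image helper used in detect converts an OpenCV Mat into a DJL Image:
// in xyz.hyhy.utils.MyUtils; needs ai.djl.modality.cv.Image, ai.djl.modality.cv.ImageFactory, org.opencv.core.Mat, org.opencv.highgui.HighGui
public static Image mat2Image(Mat mat) {
    return ImageFactory.getInstance().fromImage(HighGui.toBufferedImage(mat));
}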