Yolo Tiny 是 Yolo2 的简化版，虽然有点过时，但对于很多物体检测的应用场景还是很管用。本示例利用 DeepLearning4j 构建 Yolo 算法实现目标检测，下图是本示例的网络结构：
// Network input size and output grid. TinyYOLO downsamples its 416x416 input by a factor
// of 32, which yields a 13x13 output grid (416 / 32 = 13); these values are tied together.
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;
// Number of object classes in this custom dataset.
// NOTE(review): should equal the number of class names used by the annotations (classes.txt).
int nClasses = 5;
// Yolo2OutputLayer parameters: prior (anchor) boxes as {width, height} pairs in grid-cell
// units, and the minimum confidence for a detection to be reported.
double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } };
double detectionThreshold = 0.8;
// Training parameters.
int batchSize = 10;
int nEpochs = 20;
int seed = 123;
Random rng = new Random(seed);
// Directory holding the images together with their LabelImg (YOLO format) .txt files and classes.txt.
File imageDir = new File("D:\\train");
// Random 90% / 10% train/test split; only files with an allowed image extension are sampled.
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(null, 0.9, 0.1);
InputSplit trainData = data[0];
InputSplit testData = data[1];
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
        new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTrain.initialize(trainData);
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
        new YoloLabelProvider(imageDir.getAbsolutePath()));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader produces regression targets, hence labelIndexFrom = labelIndexTo = 1
// and regression = true.
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
// Scale 8-bit pixel values into [0, 1].
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
// Reuse a previously saved model if present; otherwise build, train and save a new one.
ComputationGraph model;
String modelFilename = "D:\\model.zip";
File modelFile = new File(modelFilename);
if (modelFile.exists()) {
    this.output("Load model...");
    model = ComputationGraph.load(modelFile, true);
} else {
    this.output("Build model...");
    // NOTE(review): init() builds the network with freshly initialised weights, i.e. training
    // from scratch. The DL4J examples instead start from initPretrained() plus transfer learning,
    // which usually needs far less data and fewer epochs.
    model = TinyYOLO.builder().numClasses(nClasses).priorBoxes(priorBoxes).build().init();
    System.out.println(model.summary(InputType.convolutional(height, width, nChannels)));
    this.output("Train model...");
    model.setListeners(new ScoreIterationListener(1));
    model.fit(train, nEpochs);
    ModelSerializer.writeModel(model, modelFile, true);
}
// Visualize detections on the test set: one image per key press, until the test set is
// exhausted or the window is closed.
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("WatermelonDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model
        .getOutputLayer(0);
// Class names indexed by predicted class id, taken from the training iterator.
List<String> labels = train.getLabels();
// Metadata is needed below to get each image's URI and original size.
test.setCollectMetaData(true);
Scalar[] colormap = { RED, BLUE, GREEN, CYAN, YELLOW, MAGENTA, ORANGE, PINK, LIGHTBLUE, VIOLET };
while (test.hasNext() && frame.isVisible()) {
    org.nd4j.linalg.dataset.DataSet ds = test.next();
    RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
    INDArray features = ds.getFeatures();
    INDArray results = model.outputSingle(features);
    List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
    File file = new File(metadata.getURI());
    // Features were scaled to [0, 1] by the preprocessor; scale back to 8-bit pixels.
    Mat mat = imageLoader.asMat(features);
    Mat convertedMat = new Mat();
    mat.convertTo(convertedMat, CV_8U, 255, 0);
    int w = metadata.getOrigW();
    int h = metadata.getOrigH();
    Mat image = new Mat();
    resize(convertedMat, image, new Size(w, h));
    // The intermediates are no longer needed; free their native memory now instead of waiting for GC.
    mat.release();
    convertedMat.release();
    for (DetectedObject obj : objs) {
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        int classIdx = obj.getPredictedClass();
        String label = labels.get(classIdx);
        // Corners are in grid-cell units; map them onto the original image size.
        int x1 = (int) Math.round(w * xy1[0] / gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / gridHeight);
        // Cycle through the palette so class ids beyond its size do not overflow the array.
        Scalar color = colormap[classIdx % colormap.length];
        rectangle(image, new Point(x1, y1), new Point(x2, y2), color);
        putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, color);
    }
    frame.setTitle(file.getName() + " - WatermelonDetection");
    frame.setCanvasSize(w, h);
    frame.showImage(converter.convert(image));
    frame.waitKey();
}
frame.dispose();
参数讲解
图片的宽高：int width = 416; int height = 416; 是固定的，与 TinyYOLO 网络的输入尺寸一致。
图片的通道数：彩色图为 int nChannels = 3，灰度图则为 nChannels = 1，默认为 3。
输出网格的大小：TinyYOLO 将 416×416 的输入下采样 32 倍，得到 13×13 的网格（416 / 32 = 13），因此 int gridWidth = 13; int gridHeight = 13; 不能随意修改。
待检测的类别个数：本示例为 5 个，即 int nClasses = 5。
先验框（anchor box）的宽和高，以网格单元为单位，每组为 {宽, 高}：double[][] priorBoxes = { { 1.5, 2.2 }, { 1.4, 1.95 }, { 1.8, 3.3 }, { 2.4, 2.9 }, { 1.7, 2.2 } }; Yolo2 中通过对训练样本的标注框做 K-means 聚类来得到先验框，代码如下：
// Estimate YOLO prior (anchor) boxes: cluster the sizes of all labelled training boxes with
// K-means++ using an IOU-based distance measure.
// One grid cell covers 416 / 13 = 32 pixels of the network input.
final double cellSize = 32d;
// Number of clusters = number of prior boxes. NOTE(review): keep equal to priorBoxes.length.
final int numPriors = 5;
final int maxIterations = 15;
YoloLabelProvider labelProvider = new YoloLabelProvider(trainDir.getAbsolutePath());
DistanceMeasure distanceMeasure = new YoloIOUDistanceMeasure();
KMeansPlusPlusClusterer<ImageObjectWrapper> clusterer =
        new KMeansPlusPlusClusterer<>(numPriors, maxIterations, distanceMeasure);
// File.listFiles returns null (not an empty array) when trainDir is not a readable directory.
File[] pngFiles = trainDir.listFiles((dir, name) -> name.endsWith(".png"));
if (pngFiles == null) {
    throw new IllegalStateException("Cannot list image files in " + trainDir.getAbsolutePath());
}
List<ImageObjectWrapper> clusterInput = Stream.of(pngFiles)
        .flatMap(png -> labelProvider.getImageObjectsForPath(png.getName()).stream())
        .map(imageObject -> new ImageObjectWrapper(imageObject))
        .filter(wrapper -> {
            double[] size = wrapper.getPoint();
            // Skip boxes that fit inside a single grid cell in both dimensions.
            return size[0] > cellSize || size[1] > cellSize;
        })
        .collect(Collectors.toList());
List<CentroidCluster<ImageObjectWrapper>> clusterResults = clusterer.cluster(clusterInput);
for (CentroidCluster<ImageObjectWrapper> centroidCluster : clusterResults) {
    double[] point = centroidCluster.getCenter().getPoint();
    List<ImageObjectWrapper> members = centroidCluster.getPoints();
    System.out.println(
            "width:" + point[0] + " height:" + point[1] + " ratio:" + point[1] / point[0] + " size:" + members.size());
    // Cluster centre expressed in grid-cell units: the {width, height} pair to put into priorBoxes.
    System.out.println("bbox amount:" + point[0] / cellSize + "," + point[1] / cellSize);
    // Guard against an empty cluster, where max() would have no element.
    if (!members.isEmpty()) {
        ImageObjectWrapper maxWidthImage = members.stream()
                .max(Comparator.comparingDouble(ImageObjectWrapper::getWidth)).get();
        ImageObjectWrapper maxHeightImage = members.stream()
                .max(Comparator.comparingDouble(ImageObjectWrapper::getHeight)).get();
        System.out.println(" width:" + maxWidthImage.getWidth() + " height:" + maxHeightImage.getHeight());
    }
    System.out.println("-----------");
}
上述代码通过 K-means 聚类获取训练样本中有代表性的框的宽高。需要把 K-means 默认的欧氏距离替换为基于 IOU 的距离（YOLOv2 中使用 d = 1 − IOU），具体可参照“YOLO v2目标检测详解二 计算iou - 灰信网(软件开发博客聚合)”。
detectionThreshold 是物体检测的置信度阈值：值越高，检测出来的物体个数越少，误检越少（准确率越高），但漏检可能增多。
我的训练集是通过 LabelImg 制作的，标注格式为 YOLO，训练样本如下。注意图片的大小要与参数 416x416 一致，且建议使用 png 格式（下文代码按 .png 扩展名查找图片）。
标签类别文件为 classes.txt，包括五个类别：xi、cake、dan、ss、bi。类别所在的行号（从 0 开始）即标注 txt 文件中的类别编号。
标签解析类 YoloLabelProvider 的代码如下。它的主要作用是把 LabelImg 生成的 txt 标注数据（归一化的中心点坐标和宽高）转换成算法可以识别的 ImageObject（以像素为单位的边界框）：
/**
 * Supplies DL4J {@link ImageObject} bounding boxes from annotations produced by LabelImg in
 * YOLO format.
 *
 * <p>{@code baseDirectory} must contain {@code classes.txt}, with one class name per line; the
 * zero-based line index is the class id used in the annotation files. For every image
 * {@code name.ext} it must also contain an annotation file {@code name.txt}. Each non-blank
 * annotation line has the form {@code classId centerX centerY width height}, where the four box
 * values are fractions of the image width/height. They are converted to pixel corners
 * (xmin, ymin, xmax, ymax) clamped to the image bounds.
 */
public class YoloLabelProvider implements ImageObjectLabelProvider {

    /** Image extension assumed when the requested path carries no extension of its own. */
    private static final String DEFAULT_IMAGE_EXTENSION = ".png";

    private final String baseDirectory;
    /** Class names read from classes.txt, indexed by class id. */
    private final List<String> labels;

    /**
     * @param baseDirectory directory holding the images, their .txt annotations and classes.txt;
     *                      must not be null and classes.txt must exist (checked via {@code Assert})
     * @throws IllegalStateException       if {@code baseDirectory} does not exist
     * @throws java.io.UncheckedIOException if classes.txt cannot be read
     */
    public YoloLabelProvider(String baseDirectory) {
        Assert.notNull(baseDirectory, "标签目录不能为空");
        this.baseDirectory = baseDirectory;
        if (!new File(baseDirectory).exists()) {
            throw new IllegalStateException(
                    "baseDirectory directory does not exist. txt files should be " + "present at Expected location: " + baseDirectory);
        }
        File classFile = new File(FilenameUtils.concat(this.baseDirectory, "classes.txt"));
        Assert.isTrue(classFile.exists(), "classTxtPath does not exist");
        this.labels = readLines(classFile);
    }

    /**
     * Returns the boxes annotated for the image at {@code path}. Only the file name of
     * {@code path} is used; both the annotation and the image are looked up in the base directory.
     *
     * @throws IllegalStateException         if the annotation file is missing, an annotation line is
     *                                       malformed or refers to an unknown class id, or the image
     *                                       format is not readable by ImageIO
     * @throws java.io.UncheckedIOException if the annotation or image file cannot be read
     */
    @Override
    public List<ImageObject> getImageObjectsForPath(String path) {
        // Accept both '/' (URI strings) and '\\' (Windows paths) as separators.
        int idx = Math.max(path.lastIndexOf('/'), path.lastIndexOf('\\'));
        String fileName = path.substring(idx + 1);
        int dot = fileName.lastIndexOf('.');
        String baseName = dot > 0 ? fileName.substring(0, dot) : fileName;
        String imageName = dot > 0 ? fileName : baseName + DEFAULT_IMAGE_EXTENSION;

        String txtPath = FilenameUtils.concat(this.baseDirectory, baseName + ".txt");
        File txtFile = new File(txtPath);
        if (!txtFile.exists()) {
            throw new IllegalStateException("Could not find TXT file for image " + path + "; expected at " + txtPath);
        }
        // The image is read only to learn its size: YOLO annotations are relative to it.
        BufferedImage image = readImage(new File(FilenameUtils.concat(this.baseDirectory, imageName)));
        int width = image.getWidth();
        int height = image.getHeight();
        return readLines(txtFile).stream()
                .filter(line -> !line.trim().isEmpty())
                .map(line -> toImageObject(line, width, height, txtFile))
                .collect(Collectors.toList());
    }

    /** Delegates to {@link #getImageObjectsForPath(String)} using the URI's decoded path. */
    @Override
    public List<ImageObject> getImageObjectsForPath(URI uri) {
        // getPath() is percent-decoded (e.g. "%20" -> " "), unlike toString(); opaque URIs have none.
        String path = uri.getPath();
        return getImageObjectsForPath(path != null ? path : uri.toString());
    }

    /** Converts one YOLO annotation line into a pixel-space box clamped to the image. */
    private ImageObject toImageObject(String line, int width, int height, File source) {
        String[] data = line.trim().split("\\s+");
        if (data.length < 5) {
            throw new IllegalStateException("Malformed YOLO label line in " + source + ": '" + line + "'");
        }
        int classId = Integer.parseInt(data[0]);
        if (classId < 0 || classId >= labels.size()) {
            throw new IllegalStateException("Class id " + classId + " in " + source + " is not listed in classes.txt");
        }
        double centerX = Double.parseDouble(data[1]) * width;
        double centerY = Double.parseDouble(data[2]) * height;
        double halfWidth = Double.parseDouble(data[3]) * width / 2;
        double halfHeight = Double.parseDouble(data[4]) * height / 2;
        // Clamp so boxes touching the border cannot yield coordinates (and box centres) outside the image.
        int xmin = clamp(Math.round(centerX - halfWidth), width - 1);
        int ymin = clamp(Math.round(centerY - halfHeight), height - 1);
        int xmax = clamp(Math.round(centerX + halfWidth), width - 1);
        int ymax = clamp(Math.round(centerY + halfHeight), height - 1);
        return new ImageObject(xmin, ymin, xmax, ymax, labels.get(classId));
    }

    /** Limits {@code value} to the range [0, max]. */
    private static int clamp(long value, int max) {
        return (int) Math.max(0, Math.min(value, max));
    }

    /** Reads an image, failing loudly instead of returning null for unsupported formats. */
    private static BufferedImage readImage(File imageFile) {
        try {
            BufferedImage image = ImageIO.read(imageFile);
            if (image == null) {
                throw new IllegalStateException("Unsupported or unreadable image format: " + imageFile);
            }
            return image;
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException("Could not read image " + imageFile, e);
        }
    }

    /** Reads all lines of a UTF-8 text file, rethrowing I/O failures unchecked. */
    private static List<String> readLines(File file) {
        try {
            return Files.readAllLines(file.toPath());
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException("Could not read " + file, e);
        }
    }
}