opencv学习笔记六十二:MobileNet实现对象检测

MobileNet是SSD模型的精简版,速度更快,因为SSD检测一百多个对象,而MobileNet只检测20种物体,速度比较快,可以实时检测。

mobilenet模型下载:https://github.com/weiliu89/caffe/tree/ssd#models,下载PASCAL VOC models下的第一个,其里面包含模型文件和描述文件,对于类别文件,可以打开opencv里的例程

E:\anzhuang\opencv3.4.1\opencv\sources\samples\dnn\ssd_mobilenet_object_detection.cpp,例程里列出了这20种对象,在此程序上进行修改如下:

#include
#include
using namespace cv;
using namespace std;
using namespace dnn;

const size_t inWidth = 300;
const size_t inHeight = 300;
const float inScaleFactor = 0.007843f;
const float meanVal = 127.5;
const char* classNames[] = { "background",
"aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair",
"cow", "diningtable", "dog", "horse",
"motorbike", "person", "pottedplant",
"sheep", "sofa", "train", "tvmonitor" };


int main(int argc, char** argv)
{	
	namedWindow("input", CV_WINDOW_AUTOSIZE);
	Mat src = imread("1.jpg");
	imshow("input",src);
	String modelConfiguration = "deploy.prototxt";
	String modelBinary = "VGG_VOC0712_SSD_300x300_iter_120000.caffemodel";
	//读入模型和描述文件
	Net net = readNetFromCaffe(modelConfiguration, modelBinary);	
	//将图像转为网络输入的模式
	Mat inputBlob = blobFromImage(src, inScaleFactor, Size(inWidth, inHeight), Scalar(meanVal, meanVal, meanVal), false, false);
	//输入网络							 					
	net.setInput(inputBlob,"data"); 		
	//前向传播	
	Mat detection = net.forward("detection_out"); 								 
	Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr());
	//设置置信度阈值
	float confidenceThreshold = 0.06;
		
		for (int i = 0; i < detectionMat.rows; i++)
		{
			float confidence = detectionMat.at(i, 2);

			if (confidence > confidenceThreshold)
			{
				size_t objectClass = (size_t)(detectionMat.at(i, 1));

				int tl_x = static_cast(detectionMat.at(i, 3) * src.cols);
				int tl_y = static_cast(detectionMat.at(i, 4) * src.rows);
				int br_x = static_cast(detectionMat.at(i, 5) * src.cols);
				int br_y = static_cast(detectionMat.at(i, 6) * src.rows);

				rectangle(src, Point(tl_x, tl_y), Point(br_x, br_y), Scalar(0, 255, 0));
				String label = format("%s: %.2f", classNames[objectClass], confidence);
		
				putText(src, label, Point(tl_x, tl_y), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
			}
		}
		imshow("detections", src);
		waitKey(0);
		return 0;	
	}

 opencv学习笔记六十二:MobileNet实现对象检测_第1张图片

你可能感兴趣的:(opencv)