OpenCV —— PSPNet 图像分割模型 PASCAL VOC 推理

题目要求：基于 PSPNet 的图像分割模型进行推理，参考模型为 PSPNet_Pytorch。

  • 结果展示
  • 示例代码
 
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

#include <opencv2/dnn.hpp>
#include <opencv2/opencv.hpp>

using namespace cv;
using namespace dnn;
using namespace std;

class baseSeg
{
public:
	// load onnx model
	baseSeg(int h, int w, const string& model_path="model/pspnet.onnx"){
		this -> inHeight = h;
		this->inWidth = w;
		this -> net = readNet(model_path);
	};
	
	void seg(Mat& frame);
private:
	Net net;
	int inWidth;
	int inHeight;
	// load classes lable and colors label
	vector<Vec3b> readColors(const string& labelFile="model/pascal_voc.txt");
};

// Read the colour palette used to paint class ids.
//
// Every non-empty line of `labelFile` is parsed as "<name> <c0> <c1> <c2>".
// The k-th non-empty line becomes the colour of class id k, with the three
// numbers stored in file order in Vec3b[0..2]. Missing or non-numeric
// components become 0, and the entry is still appended so that later class
// ids keep their positions.
//
// If the file cannot be opened, an error is printed and the process
// terminates with exit(-1).
vector<Vec3b> baseSeg::readColors(const string& labelFile)
{
	vector<Vec3b> colors;
	ifstream fp(labelFile);
	if (!fp.is_open()) {
		printf("could not open the file...\n");
		exit(-1);
	}
	string line;
	// Loop on the result of getline itself. Testing eof() before reading
	// never terminates if the stream fails without hitting end-of-file.
	while (getline(fp, line)) {
		if (line.empty()) {
			continue;
		}
		stringstream ss(line);
		string name;
		// Zero-initialise: when a line ends early, operator>> leaves the target
		// untouched, which previously meant reading an uninitialised int.
		int c0 = 0;
		int c1 = 0;
		int c2 = 0;
		ss >> name >> c0 >> c1 >> c2;
		colors.push_back(Vec3b(static_cast<uchar>(c0),
		                       static_cast<uchar>(c1),
		                       static_cast<uchar>(c2)));
	}
	return colors;
}


// Segment `frame` with the loaded network and display the result.
//
// Steps:
// 1. Build the input blob: scale by 1/255, resize to inWidth x inHeight,
//    swap R/B, no crop.
// 2. Run the network.
// 3. Label every output pixel with the channel that has the highest score.
// 4. Colour the labels with the palette from readColors().
// 5. Resize the colour map to the frame size if the two differ.
// 6. Show the input and a blend of 30% frame / 70% colour map.
//
// Blocks in waitKey(0) until a key is pressed.
void baseSeg::seg(Mat& frame){
	Mat blobImage = blobFromImage(frame, 1 / 255.0, Size(inWidth, inHeight), Scalar(0, 0, 0), true, false);
	net.setInput(blobImage);
	Mat score = net.forward();

	// The indexing below treats `score` as a float N x C x H x W blob. Fail
	// loudly for any other layout instead of reading past score.size. Class
	// ids are stored as uchar, so at most 256 channels can be represented.
	CV_Assert(score.dims == 4 && score.type() == CV_32F && score.size[1] <= 256);
	const int chns = score.size[1];
	const int rows = score.size[2];
	const int cols = score.size[3];

	// Running arg-max buffers. Both MUST start initialised: the first
	// comparison reads maxVal, and an unwritten maxCl entry would be an
	// arbitrary id that indexes outside `colors`.
	Mat maxCl(rows, cols, CV_8UC1, Scalar(0));
	Mat maxVal(rows, cols, CV_32FC1, Scalar(std::numeric_limits<float>::lowest()));

	const vector<Vec3b> colors = readColors();

	// Per-pixel arg-max over channels. Channel-outer order walks each score
	// plane row by row.
	for (int c = 0; c < chns; c++) {
		for (int row = 0; row < rows; row++) {
			const float *ptrScore = score.ptr<float>(0, c, row);
			uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
			float *ptrMaxVal = maxVal.ptr<float>(row);
			for (int col = 0; col < cols; col++) {
				if (ptrScore[col] > ptrMaxVal[col]) {
					ptrMaxVal[col] = ptrScore[col];
					ptrMaxCl[col] = static_cast<uchar>(c);
				}
			}
		}
	}

	// Map class ids to colours. Ids without a palette entry (label file
	// shorter than the number of classes) stay black instead of reading past
	// the end of `colors`.
	Mat result = Mat::zeros(rows, cols, CV_8UC3);
	for (int row = 0; row < rows; row++) {
		const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
		Vec3b *ptrColor = result.ptr<Vec3b>(row);
		for (int col = 0; col < cols; col++) {
			const size_t cls = ptrMaxCl[col];
			if (cls < colors.size()) {
				ptrColor[col] = colors[cls];
			}
		}
	}

	// addWeighted requires equally sized inputs. Nearest-neighbour keeps the
	// map made only of palette colours.
	if (result.size() != frame.size()) {
		resize(result, result, frame.size(), 0, 0, INTER_NEAREST);
	}

	Mat dst;
	namedWindow("input image", WINDOW_NORMAL);
	imshow("input image", frame);
	// fusion
	addWeighted(frame, 0.3, result, 0.7, 0, dst);
	namedWindow("Fusion Result", WINDOW_NORMAL);
	imshow("Fusion Result", dst);
	waitKey(0);
}

// Entry point: segment a single image with the model at model/pspnet.onnx.
// The image path can be given as argv[1]; it defaults to
// "inference/horses.jpg" so the original invocation still works.
// Returns 0 on success and -1 if the image cannot be loaded or inference fails.
int main(int argc, char** argv) {
	const char* imagePath = argc > 1 ? argv[1] : "inference/horses.jpg";
	Mat frame = imread(imagePath);
	if (frame.empty()) {
		printf("could not load image...\n");
		return -1;
	}
	// Network input resolution; the frame is resized to match it.
	const int h = 473, w = 473;
	// cv::Size takes (width, height).
	resize(frame, frame, Size(w, h));
	try {
		baseSeg net(h, w, "model/pspnet.onnx");
		net.seg(frame);
	} catch (const std::exception& e) {
		// OpenCV signals errors by throwing cv::Exception, which derives from
		// std::exception. Report the error instead of letting it escape main.
		printf("inference failed: %s\n", e.what());
		return -1;
	}
	return 0;
}

你可能感兴趣的：(opencv, 目标检测, 计算机视觉, 人工智能)