OpenCV4学习笔记(69)——dnn模块之基于fast_style模型实现快速图像风格迁移

今天要整理记录的内容是fast_style模型在OpenCV中使用dnn模块进行加载调用,并基于该模型实现快速的图像风格迁移。在该博文中使用到的fast_style模型是基于pytorch框架训练而成的,总共分为九个不同风格的模型,其中的不同风格是在不同艺术作品中提取出来的,乍一看会有些眼熟,然而说不上来是什么画风……

具体的图像风格如下:

模型 风格(翻译可能有误)
candy 糖果
composition_vii 康丁斯基的抽象派绘画风格
feathers 羽毛
udnie 乌迪妮
la_muse 缪斯
mosaic 镶嵌
the_wave 海浪
starry_night 星夜
the_scream 爱德华·蒙克创作绘画风格(呐喊)

那在这里我们主要是整理fast_style模型在OpenCV中的调用,所以就不去钻研有关艺术范围的问题了。下面通过代码逐步整理记录fast_style模型的调用,注意这里以candy模型为例子,其他模型的调用方法也都是一样的。

首先,加载模型并且设置计算后台和目标设备

	// Location of the serialized Torch (.t7) "candy" style model on disk.
	const string candy_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/candy.t7";
	// Parse the Torch file into a dnn::Net.
	Net candy_net = readNetFromTorch(candy_model_path);
	// Select OpenCV's own inference backend ...
	candy_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	// ... and run it on the CPU.
	candy_net.setPreferableTarget(DNN_TARGET_CPU);

然后读取要进行风格迁移的图像,将其转换为blob并传入网络的输入层,再进行前向传播

	Mat image = imread("D:/opencv_c++/opencv_tutorial/data/images/小公园.jpg");
	resize(image, image, Size(600, 600));
	Mat inputBlob = blobFromImage(image, 1.0, Size(256, 256), Scalar(103.939, 116.779, 123.68), false, false);
	candy_net.setInput(inputBlob);
	Mat prob = candy_net.forward();

得到前向传播结果矩阵prob后,就需要对其进行解码。这里得到的prob是一个四维的Mat对象,第一维度表示批次中的图像,因为只传入了一张图像,所以该维度的大小为1,访问时使用的索引为0;第二维度表示输出图像的通道;第三维度表示输出图像的高度;第四维度表示输出图像的宽度。所以我们需要的是第二、三、四个维度的信息,利用获得的输出图像信息来创建一个画布,这个画布就是用来接收风格迁移后的图像。

然后通过for循环来逐一获取prob中三个维度的值,并赋值到画布中相应位置处,从而得到风格迁移后的图像。

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}

要注意这里得到的风格迁移后的图像还不是最终的图像,因为输入图像在转换为blob时减去了均值(以减少不同光照情况的影响),网络的输出也相应地处于减去均值后的数值范围内,所以我们需要给这个图像再加上均值,并且归一化到 [0, 255] 之间,同时转换为CV_8UC3类型。

这时候我们就得到能显示的不同风格的图像了,但是由于风格迁移后,迁移图像和原图像的差距可能非常的大,甚至会存在噪声,所以我们对迁移图像进行轻微的中值滤波后,将迁移图像和原图像按照一定权重进行加权融合,以达到更好的显示效果。

	add(result, Scalar(103.939, 116.779, 123.68), result);
	normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
	medianBlur(result, result, 3);
	resize(result, result, image.size());
	addWeighted(result, 0.8, image, 0.2, 0.0, result);
	imshow("output", result);

到这里我们就实现了利用单独一个fast_style模型进行快速图像风格迁移的操作,上面提到总共分为九个不同风格的模型,所以我们可以全部都利用起来,实现一个多种风格迁移的效果。代码演示如下:

	string candy_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/candy.t7";
	Net candy_net = readNetFromTorch(candy_model_path);
	candy_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	candy_net.setPreferableTarget(DNN_TARGET_CPU);

	string composition_vii_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/composition_vii.t7";
	Net composition_vii_net = readNetFromTorch(composition_vii_model_path);
	composition_vii_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	composition_vii_net.setPreferableTarget(DNN_TARGET_CPU);

	string feathers_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/feathers.t7";
	Net feathers_net = readNetFromTorch(feathers_model_path);
	feathers_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	feathers_net.setPreferableTarget(DNN_TARGET_CPU);

	string la_muse_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/la_muse.t7";
	Net la_muse_net = readNetFromTorch(la_muse_model_path);
	la_muse_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	la_muse_net.setPreferableTarget(DNN_TARGET_CPU);

	string mosaic_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/mosaic.t7";
	Net mosaic_net = readNetFromTorch(mosaic_model_path);
	mosaic_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	mosaic_net.setPreferableTarget(DNN_TARGET_CPU);

	string starry_night_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/starry_night.t7";
	Net starry_night_net = readNetFromTorch(starry_night_model_path);
	starry_night_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	starry_night_net.setPreferableTarget(DNN_TARGET_CPU);

	string the_scream_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/the_scream.t7";
	Net the_scream_net = readNetFromTorch(the_scream_model_path);
	the_scream_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	the_scream_net.setPreferableTarget(DNN_TARGET_CPU);

	string the_wave_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/the_wave.t7";
	Net the_wave_net = readNetFromTorch(the_wave_model_path);
	the_wave_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	the_wave_net.setPreferableTarget(DNN_TARGET_CPU);

	string udnie_model_path = "D:/opencv_c++/opencv_tutorial/data/models/fast_style/udnie.t7";
	Net udnie_net = readNetFromTorch(udnie_model_path);
	udnie_net.setPreferableBackend(DNN_BACKEND_OPENCV);
	udnie_net.setPreferableTarget(DNN_TARGET_CPU);

	Mat image = imread("D:/opencv_c++/opencv_tutorial/data/images/小公园.jpg");
	resize(image, image, Size(600, 600));
	while (image.data != NULL)
	{
		imshow("input_image", image);
		Mat inputBlob = blobFromImage(image, 1.0, Size(256, 256), Scalar(103.939, 116.779, 123.68), false, false);
		char flag = cv::waitKey(1);
		switch (flag)
		{
			case '1':
			{
				candy_net.setInput(inputBlob);
				Mat prob = candy_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '2':
			{
				composition_vii_net.setInput(inputBlob);
				Mat prob = composition_vii_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '3':
			{
				feathers_net.setInput(inputBlob);
				Mat prob = feathers_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '4':
			{
				la_muse_net.setInput(inputBlob);
				Mat prob = la_muse_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '5':
			{
				mosaic_net.setInput(inputBlob);
				Mat prob = mosaic_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '6':
			{
				starry_night_net.setInput(inputBlob);
				Mat prob = starry_night_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '7':
			{
				the_scream_net.setInput(inputBlob);
				Mat prob = the_scream_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '8':
			{
				the_wave_net.setInput(inputBlob);
				Mat prob = the_wave_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case '9':
			{
				udnie_net.setInput(inputBlob);
				Mat prob = udnie_net.forward();

				int channels = prob.size[1];
				int height = prob.size[2];
				int width = prob.size[3];
				Mat result = Mat::zeros(Size(width, height), CV_32FC3);
				for (int ch = 0; ch < channels; ch++)
				{
					for (int row = 0; row < height; row++)
					{
						float* prob_ptr = prob.ptr<float>(0, ch, row);
						for (int col = 0; col < width; col++)
						{
							result.at<Vec3f>(row, col)[ch] = prob_ptr[col];
						}
					}
				}
				add(result, Scalar(103.939, 116.779, 123.68), result);
				normalize(result, result, 0, 255, NORM_MINMAX, CV_8UC3);
				medianBlur(result, result, 3);
				resize(result, result, image.size());
				addWeighted(result, 0.8, image, 0.2, 0.0, result);
				imshow("output", result);
				break;
			}
			case ' ':
			{
				imshow("output", image);
				break;
			}
		default:
			break;
		}
		if (flag == 27)
		{
			break;
		}
	}

那实现的效果如下图所示:

candy效果

composition_vii效果
feathers效果

la_muse效果

mosaic效果

starry_night效果

the_scream效果

the_wave效果

udnie效果

上面就是全部九种图像风格啦,有兴趣的朋友可以自己做一下快速图像风格迁移哦,我感觉对于不同图像的迁移效果还是有些许区别的。好的那本次笔记就整理到此结束啦,谢谢阅读~(PS:话说有人知道上图是哪里的景点吗哈哈哈)

PS:本人的注释比较杂,既有自己的心得体会也有网上查阅资料时摘抄下的知识内容,所以如有雷同,纯属我向前辈学习的致敬,如果有前辈觉得我的笔记内容侵犯了您的知识产权,请和我联系,我会将涉及到的博文内容删除,谢谢!

你可能感兴趣的:(学习笔记)