上一章《训练模型调用及转换》已把训练好的模型转换为 C++ 可通过 libTorch 调用的格式。想了解在 Windows 下用 C++ 调用 libTorch 的方法，可参看《C++(libTorch)调用pytorch预训练模型》，本章不再介绍。
首先调用 torch::jit::load() 加载训练好的模型。待预测的图像使用 OpenCV 的 cv::imread() 读取，再通过 torch::from_blob 转换为 torch 张量。
接下来把两张图像张量沿通道维拼接，送入第一个模型预测两帧之间的双向光流。将输出张量拆分后，估算中间时刻的光流，并据此对两张原图做反向变形（warp）。然后把这些处理后的张量连同两张原始图像张量和光流张量拼接起来，调用第二个模型进行预测。最后把第二个模型的输出再次拆分，用修正后的光流分别变形两张原图，再按模型输出的权重融合成最终结果。以下是 C++ 代码：
std::vector
cv::Mat *pImage1, *pImage2;
pImage1 = (*pMats)[0];
pImage2 = (*pMats)[1];
CTDTorchJitModule *pTorchJitModule = (CTDTorchJitModule*)pModule;
torch::jit::script::Module* pScriptModule1 = pTorchJitModule->GetModule1();
torch::jit::script::Module* pScriptModule2 = pTorchJitModule->GetModule2();
pScriptModule1->to(pTorchJitModule->GetDeviceType());
pScriptModule2->to(pTorchJitModule->GetDeviceType());
int w, h;
w = pImage1->cols / 32 * 32;
h = pImage1->rows / 32 * 32;
pImage1->resize((w, h));
pImage2->resize((w, h));
//std::vector
std::vector<int64_t> sizes = { 1, pImage1->rows, pImage1->cols };
at::Tensor tensor_image1 = torch::from_blob(pImage1->data, at::IntList(sizes), at::ScalarType::Byte).to(pTorchJitModule->GetDeviceType()).unsqueeze(0) / 255.0;
//tensor_image1 = tensor_image1.permute({ 0,3,1,2 });
tensor_image1 = tensor_image1.toType(at::kFloat);
at::Tensor tensor_image2 = torch::from_blob(pImage2->data, at::IntList(sizes), at::ScalarType::Byte).to(pTorchJitModule->GetDeviceType()).unsqueeze(0) / 255.0;
//tensor_image2 = tensor_image2.permute({ 0,3,1,2 });
tensor_image2 = tensor_image2.toType(at::kFloat);
vector<Tensor> vecTensor;
vecTensor.push_back(tensor_image1);
vecTensor.push_back(tensor_image2);
TensorList tl(vecTensor);
at::Tensor tensor_image = torch::cat(tl, 1);
at::Tensor output = pScriptModule1->forward({ tensor_image }).toTensor();
double t = 0.5;
double temp = -t * (1 - t);
double co_eff[4];
co_eff[0] = temp;
co_eff[1] = t * t;
co_eff[2] = (1 - t) * (1 - t);
co_eff[3] = temp;
at::Tensor f01 = output.slice(1, 0, 2, 1);
at::Tensor f10 = output.slice(1, 2, 4, 1);
at::Tensor ft0 = co_eff[0] * f01 + co_eff[1] * f10;
at::Tensor ft1 = co_eff[2] * f01 + co_eff[3] * f10;
at::Tensor u = ft0.slice(1, 0, 1, 1).squeeze(0);
at::Tensor v = ft0.slice(1, 1, 2, 1).squeeze(0);
at::Tensor gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
at::Tensor gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
at::Tensor x = gridX + u;
at::Tensor y = gridY + v;
x = 2 * (x / w - 0.5);
y = 2 * (y / h - 0.5);
at::Tensor grid = torch::stack({ x, y }, 3);
at::Tensor gi0ft0 = torch::grid_sampler(tensor_image1, grid, 0, 0, false);
u = ft1.slice(1, 0, 1, 1).squeeze(0);
v = ft1.slice(1, 1, 2, 1).squeeze(0);
gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
x = gridX + u;
y = gridY + v;
x = 2 * (x / w - 0.5);
y = 2 * (y / h - 0.5);
grid = torch::stack({ x, y }, 3);
at::Tensor gi1ft1 = torch::grid_sampler(tensor_image2, grid, 0, 0, false);
at::Tensor iy = torch::cat({ tensor_image1, tensor_image2, f01, f10, ft1, ft0, gi1ft1, gi0ft0 }, 1);
at::Tensor io = pScriptModule2->forward({ iy }).toTensor();
at::Tensor ft0f = io.slice(1, 0, 2, 1) + ft0;
at::Tensor ft1f = io.slice(1, 2, 4, 1) + ft1;
at::Tensor vt0 = sigmoid(io.slice(1, 4, 5, 1));
at::Tensor vt1 = 1 - vt0;
u = ft0f.slice(1, 0, 1, 1).squeeze(0);
v = ft0f.slice(1, 1, 2, 1).squeeze(0);
gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
x = gridX + u;
y = gridY + v;
x = 2 * (x / w - 0.5);
y = 2 * (y / h - 0.5);
grid = torch::stack({ x, y }, 3);
at::Tensor gi0ft0f = torch::grid_sampler(tensor_image1, grid, 0, 0, false);
u = ft1f.slice(1, 0, 1, 1).squeeze(0);
v = ft1f.slice(1, 1, 2, 1).squeeze(0);
gridX = pTorchJitModule->GetGridTensor(0)->to(pTorchJitModule->GetDeviceType()).expand_as(u);
gridY = pTorchJitModule->GetGridTensor(1)->to(pTorchJitModule->GetDeviceType()).expand_as(v);
x = gridX + u;
y = gridY + v;
x = 2 * (x / w - 0.5);
y = 2 * (y / h - 0.5);
grid = torch::stack({ x, y }, 3);
at::Tensor gi1ft1f = torch::grid_sampler(tensor_image2, grid, 0, 0, false);
co_eff[0] = 1 - t;
co_eff[1] = t;
at::Tensor ft_p = (co_eff[0] * vt0 * gi0ft0f + co_eff[1] * vt1 * gi1ft1f) / (co_eff[0] * vt0 + co_eff[1] * vt1);
CTDTorchJitTensor* pTensor = new CTDTorchJitTensor;
pTensor->SetTensor(ft_p);
两张原始图
效果图
考虑到训练速度和显存限制，这里主动减少了 U-Net 的层数并降低了图像位数；不做这些缩减时，实际效果会更好些。