#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
#include <opencv2/opencv.hpp>
using namespace cv;
/**
 * @brief Resize a BGR image to 256x256 and apply ImageNet mean/std normalization.
 *
 * Each channel is scaled to [0, 1] and then normalized as (x - mean) / std. The
 * statistics in RGB order are mean = {0.485, 0.456, 0.406} and
 * std = {0.229, 0.224, 0.225}. The channels are then reordered from OpenCV's
 * BGR to RGB.
 *
 * @param image      Input image, non-empty, 3 channels in BGR order (as produced
 *                   by cv::imread). Any depth accepted by convertTo works.
 * @param image_blob Output: 256x256 CV_32FC3 image with interleaved (HWC) pixels
 *                   in R, G, B order. This is NOT a planar NCHW buffer; callers
 *                   feeding a {1,3,H,W} tensor must repack it per channel.
 */
void PreProcess(const cv::Mat &image, cv::Mat &image_blob)
{
    // Fail loudly rather than indexing missing channels further down.
    CV_Assert(!image.empty() && image.channels() == 3);

    cv::Mat resized;
    cv::resize(image, resized, cv::Size(256, 256));
    resized.convertTo(resized, CV_32F);  // keeps the channel count -> CV_32FC3

    std::vector<cv::Mat> channels;
    cv::split(resized, channels);

    // OpenCV stores pixels as B, G, R; normalize each with its own mean/std.
    cv::Mat B = (channels[0] / 255. - 0.406) / 0.225;
    cv::Mat G = (channels[1] / 255. - 0.456) / 0.224;
    cv::Mat R = (channels[2] / 255. - 0.485) / 0.229;

    // The model expects RGB, so merge the normalized planes back as R, G, B.
    const std::vector<cv::Mat> rgb{R, G, B};
    cv::merge(rgb, image_blob);
}
int main(int argc, char* argv[]) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
#ifdef _WIN32
//erfnet 模型
const wchar_t* model_path = L"C:\\Users\\Desktop\\switch.onnx";
#else
const char* model_path = "model.onnx";
#endif
Ort::Session session(env, model_path, session_options);
Ort::AllocatorWithDefaultOptions allocator;
//获取输入name
const char* inputName = session.GetInputName(0, allocator);
//std::cout <<"xxxx:"<< inputName << std::endl;
//获取输出name
const char* outputName = session.GetOutputName(0, allocator);
//std::cout << "Output Name: " << outputName << std::endl;
查看模型定义的输入输出数
//size_t num_input_nodes = session.GetInputCount();
//size_t num_out_nodes = session.GetOutputCount();
//std::cout << num_input_nodes << " " << num_out_nodes << std::endl;
std::vector input_node_names{ inputName };
std::vector output_node_names{ outputName };
std::vector inputTensors;
std::vector outputTensors;
//设置输入维度
std::vector input_node_dims = {1,3,256,256};//(bitch_size,channel,height,width)
size_t input_tensor_size = 256*256*3*1;
std::vector input_tensor_values(input_tensor_size);
//加载图片
Mat img = imread("C:\\Users\\Desktop\\onnx_test\\roi.jpg");
//预处理
Mat dst(256, 256, CV_8UC3);
Mat dst2,dst3;
resize(img, dst, Size(256, 256));
//转换为float型
dst.convertTo(dst2, CV_32F);
//减均值除方差,注意B、R通道交换问题
PreProcess(dst2,dst3);
//将mat数据转成vector存储形式
input_tensor_values.assign(dst3.begin(),dst3.end());
//创建输入tensor
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), input_node_dims.size());
assert(input_tensor.IsTensor());
inputTensors.push_back(std::move(input_tensor));
//设置输出维度
std::vector out_node_dims = { 1,9,256,256 };//9=8+1:类+背景
size_t out_tensor_size = 256 * 256 * 9*1;
std::vector out_tensor_values(out_tensor_size);
//创建输出tensor
Ort::Value out_tensor = Ort::Value::CreateTensor(memory_info, out_tensor_values.data(), out_tensor_size, out_node_dims.data(), out_node_dims.size());
assert(out_tensor.IsTensor());
outputTensors.push_back(out_tensor);
//推理
session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), inputTensors.data(), 1,output_node_names.data(), outputTensors.data(), 1);
//解码推理数据
cv::Mat out0 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out1 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out2 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out3 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out4 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out5 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out6 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out7 = Mat::zeros(256, 256, CV_8UC1);
cv::Mat out8 = Mat::zeros(256, 256, CV_8UC1);
for (int i = 0; i < 256; i++)
{
for (int j = 0; j < 256; j++)
{
float a0 = out_tensor_values[0 * 256 * 256 + i * 256 + j];
float a1 = out_tensor_values[1 * 256 * 256 + i * 256 + j];
float a2 = out_tensor_values[2 * 256 * 256 + i * 256 + j];
float a3 = out_tensor_values[3 * 256 * 256 + i * 256 + j];
float a4 = out_tensor_values[4 * 256 * 256 + i * 256 + j];
float a5 = out_tensor_values[5 * 256 * 256 + i * 256 + j];
float a6 = out_tensor_values[6 * 256 * 256 + i * 256 + j];
float a7 = out_tensor_values[7 * 256 * 256 + i * 256 + j];
float a8 = out_tensor_values[8 * 256 * 256 + i * 256 + j];
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a0) out0.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a1) out1.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a2) out2.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a3) out3.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a4) out4.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a5) out5.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a6) out6.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a7) out7.at(i, j) = 255;
if (max(max(max(max(max(max(max(max(a0, a1), a2), a3), a4), a5), a6), a7), a8) == a8) out8.at(i, j) = 255;
}
}
//保存每个分割类别的mask
imwrite("C:\\Users\\Desktop\\roi_0.jpg", out0);
imwrite("C:\\Users\\Desktop\\roi_1.jpg", out1);
imwrite("C:\\Users\\Desktop\\roi_2.jpg", out2);
imwrite("C:\\Users\\Desktop\\roi_3.jpg", out3);
imwrite("C:\\Users\\Desktop\\roi_4.jpg", out4);
imwrite("C:\\Users\\Desktop\\roi_5.jpg", out5);
imwrite("C:\\Users\\Desktop\\roi_6.jpg", out6);
imwrite("C:\\Users\\Desktop\\roi_7.jpg", out7);
imwrite("C:\\Users\\Desktop\\roi_8.jpg", out8);
return 0;
}