使用PyTorch自带的torch.onnx即可,需要1.1以上版本。这里有一点需要注意:torch的部分API是onnx不支持的,如果转换的时候报错,就把模型里对应的函数改成onnx支持的写法。有些文章里说可以通过设置opset_version=12来解决,但是这样会导致后面转换到ncnn或者mnn的时候转换失败,应该是ncnn还没支持到更高版本的onnx的原因。在最后导出之前有个torch.randn()函数,它的参数格式是[b, c, h, w](PyTorch的NCHW顺序),这里也不是随便写的:b固定是1,c是模型的输入通道数,后面的h和w就是模型输入的高和宽。这里一旦固定了,后面在第5步的时候C++里的输入尺寸也就固定了。
# -*- coding:utf-8 -*-
# name: convert_onnx
# author: bqh
# datetime:2020/6/17 10:31
# =========================
import torch
def load_model(model, pretrained_path):
    """Load the checkpoint at ``pretrained_path`` into ``model`` and return ``model``.

    The checkpoint is read with a ``map_location`` that returns each
    deserialized storage unchanged. When the loaded object has a
    ``"state_dict"`` key, that entry is used as the weights; otherwise the
    loaded object itself is used. The weights are passed through
    ``remove_prefix(..., 'module.')`` and ``check_keys`` (both defined outside
    this block) before being applied with ``strict=False``.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    # Some checkpoints wrap the weights under "state_dict"; others are the weights.
    weights = checkpoint['state_dict'] if "state_dict" in checkpoint.keys() else checkpoint
    weights = remove_prefix(weights, 'module.')
    check_keys(model, weights)
    # strict=False: mismatching keys do not raise here.
    model.load_state_dict(weights, strict=False)
    return model
# Destination of the exported ONNX file and location of the trained checkpoint.
output_onnx = '../weights/output.onnx'
raw_weights = '../weights/model.pth'

# Build the network (you_net is a placeholder for your own model class) and
# load the trained weights into it.
net = you_net()
net = load_model(net, raw_weights)
net.eval()
print('Finished loading model!')

# Use the GPU when one is available, otherwise fall back to the CPU so the
# export also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)

# Blob names written into the ONNX graph; the C++ code feeds "input0" and
# extracts "output0" by these names.
input_names = ["input0"]
output_names = ["output0"]

# Dummy input of shape (1, 3, 300, 300) used to trace the model.
inputs = torch.randn(1, 3, 300, 300).to(device)

# Public export entry point (torch.onnx._export is a private helper).
torch.onnx.export(
    net,
    inputs,
    output_onnx,
    export_params=True,
    verbose=False,
    keep_initializers_as_inputs=True,
    input_names=input_names,
    output_names=output_names,
)
pip3 install onnx-simplifier onnxruntime
这一步一定要做,否则后面转ncnn的时候会报错
python3 -m onnxsim model.onnx model_sim.onnx
主要参考ncnn官网的教程即可,windows下的编译和上一篇MNN的编译差不多,只有一点需要说明:官网的教程上有vulkan-sdk的安装,然后要打开-DNCNN_VULKAN=ON编译选项。我一切照做后,编译出来的ncnn.lib在执行完ncnn::Extractor ex = Net->create_extractor();以及其后的所有操作、函数返回的时候就报堆栈溢出错误,包括加载官网给出的例子也全部报错;后来cmake的时候不打开这个编译选项,编译出来的ncnn.lib就一切正常了。可能是自己的问题,也没去深究。反正能用就OK了。我把编译出来的ncnn.lib、libncnn.a和linux下的onnx2ncnn工具都放在了我的网盘里,不想被编译折磨的就直接去下吧。如果编译遇见问题,也可以给我留言,哈哈~
说明:ncnnd.lib是windows下的debug版本,ncnn.lib是release版本,libncnn.a是linux下的库文件,onnx2ncnn是linux下的转换工具。
下载地址:NCNN 提取码:6cuc
对于vs中lib库和include目录的配置就不赘述了,有不懂的之前的文章有提过,假定工程已经配置完成。大体的调用过程NCNN和MNN都差不多,先加载模型创建一个指向模型的指针,然后创建session、创建用于处理输入的tensor,将input_tensor送入session,运行session,最后得到网络的输出。如果对C++比较熟悉的话,看着官网的教程比葫芦画瓢即可,只有一个地方需要说明就是对输出的获得。先看下我的代码和官网的代码再说为什么
我的输出:
// run net and get output
// `out` receives the blob named "output0" and `out1` the blob named "376".
// Only the first extract call stores its return code in `ret`; the second
// call's return code is discarded.
ncnn::Mat out, out1;
ret = ex.extract("output0", out);
ex.extract("376", out1);
官网的例子输出:
// Fetch the blob named "detection_out" into `out`; the return code is not checked.
ncnn::Mat out;
ex.extract("detection_out", out);
辣么问题来了,我的"output0"和"376"、官网的“detection_out”都哪里来的?有两个地方可以得到,最简单的方法,使用MNN框架下的转换工具,在转换完成的时候会给出模型的输入和输出名称,直接拷贝即可
>MNNConvert.exe -f ONNX --modelFile model.onnx --MNNModel slime.mnn --bizCode biz
MNNConverter Version: 0.2.1.5git - MNN @ 2018
Start to Convert Other Model Format To MNN Model...
[17:49:58] :29: ONNX Model ir version: 6
Start to Optimize the MNN Net...
[17:49:58] :20: Inputs: input0
[17:49:58] :37: Outputs: output0, Type = Concat
[17:49:58] :37: Outputs: 376, Type = Softmax
Converted Done!
如果没有MNN的转换工具,在后面加载模型后单步跟一下,在Net = new ncnn::Net()变量中有个blob变量,在内存中查看一下,里面存的有模型的各个层的名称。代码中的img_w,img_h就是在第二步转换的时候你指定的w,h。这里只写了核心调用函数,具体使用时还请自行添加一些辅助函数!
detection.h
#pragma once
#include
#include
#include
#include "net.h"
#include
#include
#include
#include
#include
#include
#include "omp.h"
/// Box with two corner points and one extra value `s`
/// (presumably the confidence score -- confirm against the code that fills it).
/// Members default to zero so a default-constructed bbox never holds
/// indeterminate values; aggregate initialization still works (C++14 and later).
struct bbox {
    float x1 = 0.f;
    float y1 = 0.f;
    float x2 = 0.f;
    float y2 = 0.f;
    float s = 0.f;
};
/// Box described by four floats `cx`, `cy`, `sx`, `sy`
/// (presumably center and size -- confirm against the code that uses it).
/// Members default to zero so a default-constructed box never holds
/// indeterminate values; aggregate initialization still works (C++14 and later).
struct box {
    float cx = 0.f;
    float cy = 0.f;
    float sx = 0.f;
    float sy = 0.f;
};
/// One detection result handed back to the caller of ObjectDetection::Detect.
/// Members default to zero so a default-constructed ObjectInfo never holds
/// indeterminate values; aggregate initialization still works (C++14 and later).
struct ObjectInfo {
    float x1 = 0.f;   // left edge of the bounding box
    float y1 = 0.f;   // top edge of the bounding box
    float x2 = 0.f;   // right edge of the bounding box
    float y2 = 0.f;   // bottom edge of the bounding box
    float prob = 0.f; // confidence
};
/// Owns an ncnn network and runs detection on raw pixel buffers.
///
/// The class holds the network through a raw owning pointer that the
/// destructor deletes, so copying is disabled: a copied instance would delete
/// the same network twice.
class ObjectDetection
{
private:
    float _nms = 0.4f;        // passed to NMS() in Detect
    float _threshold = 0.6f;  // not read by the code visible in this file
    // Per-channel values handed to substract_mean_normalize() in Detect.
    const float mean_vals[3] = { 104.f, 117.f, 123.f };
    const float norm_vals[3] = { 1.0 / 104.0, 1.0 / 117.0, 1.0 / 123.0 };
    cv::Mat img;
    ncnn::Net *Net = nullptr; // owned; created in the constructor, released in the destructor
    // Size the input image is resized to before inference.
    int img_w = 300;
    int img_h = 300;
    int numThread = 1;        // overwritten by the constructor's num_thread argument
    int detect_count = 0;
    // Comparator handed to std::sort in Detect.
    static inline bool cmp(bbox a, bbox b);
public:
    /// @param modelFolder folder prefix prepended to "Detect.param" / "Detect.bin"
    /// @param num_thread  value stored in numThread
    ObjectDetection(std::string modelFolder, int num_thread);
    ~ObjectDetection();

    // Non-copyable: the instance uniquely owns Net.
    ObjectDetection(const ObjectDetection &) = delete;
    ObjectDetection &operator=(const ObjectDetection &) = delete;

    /// Runs the network on inputImage (inputw x inputh pixels) and reports results through obj.
    int Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj);
};
detection.cpp
#include "Detection.h"
#include
/// Creates the ncnn network and loads "Detect.param" / "Detect.bin" from modelFolder.
///
/// @param modelFolder directory containing the model files; accepted with or
///                    without a trailing '/' or '\\'
/// @param num_thread  stored in numThread
ObjectDetection::ObjectDetection(std::string modelFolder, int num_thread)
{
    // Without a trailing separator the file name would be glued onto the
    // folder name ("dirDetect.param"), so add one when it is missing.
    if (!modelFolder.empty())
    {
        const char last = modelFolder[modelFolder.size() - 1];
        if (last != '/' && last != '\\')
        {
            modelFolder += '/';
        }
    }
    const std::string model_param = modelFolder + "Detect.param";
    const std::string model_bin = modelFolder + "Detect.bin";

    Net = new ncnn::Net();
    // NOTE(review): the load results are not reported to the caller; a
    // constructor has no return value to carry them.
    int ret = Net->load_param(model_param.c_str());
    ret = Net->load_model(model_bin.c_str());
    (void)ret;
    numThread = num_thread;
}
/// Releases the network owned through Net.
ObjectDetection::~ObjectDetection()
{
    // delete on a null pointer does nothing, so no explicit check is required.
    delete Net;
    Net = nullptr;
}
int ObjectDetection::Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj)
{
int ret = -1;
ncnn::Mat in = ncnn::Mat::from_pixels_resize(inputImage, ncnn::Mat::PIXEL_BGR, inputw, inputh, img_w, img_h);
in.substract_mean_normalize(mean_vals, norm_vals);
ncnn::Extractor ex = Net->create_extractor();
ex.set_light_mode(true);
ret = ex.input("input0", in);
// run net and get output
ncnn::Mat out, out1;
// bbox的输出
ret = ex.extract("output0", out);
ex.extract("376", out1);
// get result
for (int i = 0; i < out.h; ++i)
{
// 得到网络的具体输出
const float *boxes = out.row(i);
const float *scores = out1.row(i);
// 执行你自己的操作
}
std::sort(total_box.begin(), total_box.end(), cmp);
NMS(total_box, _nms);
return 0;
}
在Android Studio中配置NDK,具体配置网上有很多教程我就不啰嗦了,假定Android Studio的JNI C++环境已经配置完成。源码中的函数名格式是JNI C++要求的,必须采用这种格式,并根据实际情况修改:函数名中的"com_example_demokit_Detection"对应到Java的应用中就是"com.example.demokit.Detection",这样就很好理解了。
native-lib.cpp
#include
#include
#include "Detection.h"
#include
/// JNI entry point for Detection.Create(String): builds an ObjectDetection for
/// the given model folder and returns it to Java as an opaque handle.
///
/// @return the detector pointer as a jlong, or 0 when the path is unavailable
extern "C" JNIEXPORT jlong JNICALL
Java_com_example_demokit_Detection_Create(JNIEnv *env, jobject instance, jstring path) {
    if (path == nullptr) {
        return 0;
    }
    const char *utf_path = env->GetStringUTFChars(path, nullptr);
    if (utf_path == nullptr) {
        return 0;
    }
    // The constructor takes std::string by value, so the characters are
    // copied before the JNI buffer is released below.
    ObjectDetection *detector = new ObjectDetection(utf_path, 2);
    env->ReleaseStringUTFChars(path, utf_path);
    return reinterpret_cast<jlong>(detector);
}
/// JNI entry point for Detection.Detect(long handle, int w, int h, byte[] data).
///
/// The parameter list mirrors the Java native declaration exactly; an extra
/// native parameter would shift every argument that follows it.
///
/// Returned layout: element 0 is the object count, followed by 5 ints per
/// object (x1, y1, x2, y2, prob).
///
/// @return the packed int array, or nullptr on an invalid handle/input
extern "C" JNIEXPORT jintArray JNICALL
Java_com_example_demokit_Detection_Detect(JNIEnv *env, jobject instance, jlong handle, jint w, jint h, jbyteArray data_) {
    if (handle == 0 || data_ == nullptr || w <= 0 || h <= 0) {
        return nullptr;
    }
    ObjectDetection *detector = reinterpret_cast<ObjectDetection *>(handle);

    // ObjectDetection::Detect reads the buffer as PIXEL_BGR (3 bytes per
    // pixel); refuse buffers that are too short instead of reading past the end.
    const long long needed = 3LL * w * h;
    if (static_cast<long long>(env->GetArrayLength(data_)) < needed) {
        return nullptr;
    }

    jbyte *data = env->GetByteArrayElements(data_, nullptr);
    if (data == nullptr) {
        return nullptr;
    }
    std::vector<ObjectInfo> objects;
    detector->Detect(reinterpret_cast<unsigned char *>(data), w, h, objects);
    // The visible Detect implementation only hands the pixels to
    // from_pixels_resize, so there is nothing to copy back into the Java array.
    env->ReleaseByteArrayElements(data_, data, JNI_ABORT);

    const std::size_t kFieldsPerObject = 5;
    std::vector<jint> packed(objects.size() * kFieldsPerObject + 1);
    packed[0] = static_cast<jint>(objects.size());
    for (std::size_t i = 0; i < objects.size(); ++i) {
        const ObjectInfo &info = objects[i];
        jint *slot = &packed[i * kFieldsPerObject + 1];
        slot[0] = static_cast<jint>(info.x1);
        slot[1] = static_cast<jint>(info.y1);
        slot[2] = static_cast<jint>(info.x2);
        slot[3] = static_cast<jint>(info.y2);
        // NOTE(review): prob is a float; converting it to int truncates the
        // fraction. Changing the encoding needs a matching change on the Java side.
        slot[4] = static_cast<jint>(info.prob);
    }

    const jsize length = static_cast<jsize>(packed.size());
    jintArray jarr = env->NewIntArray(length);
    if (jarr == nullptr) {
        return nullptr;
    }
    env->SetIntArrayRegion(jarr, 0, length, packed.data());
    return jarr;
}
package com.example.demokit;
/**
 * Java-side wrapper around the native detector in "native-lib".
 * The value returned by the native Create call is kept as a handle and passed
 * back to the native Detect call.
 */
public class Detection {
    static {
        System.loadLibrary("native-lib");
    }

    // Implemented in native code.
    private native long Create(String path);
    private native int[] Detect(long handle, int w, int h, byte[] data);

    /** Value returned by the native Create call. */
    private long handle;

    /**
     * @param path passed unchanged to the native Create call
     */
    public Detection(String path) {
        this.handle = this.Create(path);
    }

    /**
     * Forwards to the native Detect call using the stored handle.
     *
     * @param w    image width
     * @param h    image height
     * @param data image bytes
     * @return the int array produced by the native call
     */
    public int[] Detect(int w, int h, byte[] data) {
        int[] packed = this.Detect(this.handle, w, h, data);
        return packed;
    }
}
在应用层就可以直接调用上面的java类啦,搞定~
package com.example.demokit;
import android.graphics.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Application-level helper that runs detection and unpacks the int array
 * returned by {@link Detection#Detect(int, int, byte[])} into ObjectInfo objects.
 */
public class DetectTool {
    private Detection mDetection;
    // Ints per object: two box corners with x and y each, plus one confidence value.
    private static final int DATA_LENGTH = 5;

    /**
     * @param dect_model_dir directory path that contains the detection model
     */
    public DetectTool(String dect_model_dir) {
        mDetection = new Detection(dect_model_dir);
    }

    /**
     * Parses one object's slice of the packed array into an ObjectInfo.
     *
     * @param src_array array of DATA_LENGTH ints: x1, y1, x2, y2, confidence
     * @return the populated ObjectInfo
     */
    private ObjectInfo ArrayAnalysis(int[] src_array) {
        ObjectInfo obj_info = new ObjectInfo();
        Point[] pointFaceBox = new Point[2];
        // bounding-box corner coordinates
        for (int i = 0; i < 2; i++) {
            Point point = new Point();
            point.x = src_array[2 * i];
            point.y = src_array[2 * i + 1];
            pointFaceBox[i] = point;
        }
        // NOTE(review): pointFaceBox is built but never stored in obj_info;
        // attach it once the matching ObjectInfo setter is known.
        // confidence
        obj_info.setProb(src_array[DATA_LENGTH - 1]);
        return obj_info;
    }

    /**
     * Runs detection on one image and returns the detected objects.
     *
     * @param width  image width
     * @param height image height
     * @param data   image bytes
     * @return list of detected objects; empty when the native call returns
     *         no result
     */
    public List<ObjectInfo> GetObjectInfo(int width, int height, byte[] data) {
        List<ObjectInfo> obj_list = new ArrayList<>();
        int[] obj = mDetection.Detect(width, height, data);
        // The native side returns null on an invalid handle or input.
        if (obj == null || obj.length == 0) {
            return obj_list;
        }
        // Never trust the reported count beyond what the array actually holds.
        int obj_count = Math.min(obj[0], (obj.length - 1) / DATA_LENGTH);
        for (int i = 0; i < obj_count; i++) {
            int[] obj_array = Arrays.copyOfRange(obj, i * DATA_LENGTH + 1, (i + 1) * DATA_LENGTH + 1);
            ObjectInfo obj_info = this.ArrayAnalysis(obj_array);
            obj_list.add(obj_info);
        }
        return obj_list;
    }
}