使用NCNN在移动端部署深度学习模型

一、整体流程概览

  1. 训练模型,使用各种你熟悉的框架我用的是pytorch
  2. 将*.pth转换成onnx, 优化onnx模型
  3. 使用转换工具转换成可供ncnn使用的模型
  4. 编译ncnn框架,并编写c++代码调用上一步转换的模型,得到模型的输出结果,封装成可供调用的类
  5. 使用JNIC++调用上一步C++封装的类,提供出接口
  6. 在安卓端编写java代码再次封装一次,供应用层调用

二、将*.pth转换成onnx

使用pytorch自带的torch.onnx即可,需要1.1版本以上,这里有一点需要注意,torch的API有些是onnx不支持的,如果转换的时候报错就把模型里的函数改成onnx支持的吧,有些文章里说这里可以设置opset_version=12来解决,但是这样的话在后面转换到ncnn或者mnn的时候造成转换失败,应该是ncnn还没支持到更高版本的onnx的原因。在最后输出之前有个torch.randn()函数,这里的参数格式是[b,c, w,h]这里也不是随便写的,b固定是1了,你模型的输入通道是多少就写多少,后面的就是模型的输入,这里一旦固定了,后面在第5步的时候c++里的输入也就固定了

convet2onnx.py

# -*- coding:utf-8 -*-
# name: convert_onnx
# author: bqh
# datetime:2020/6/17 10:31
# =========================
import torch
def load_model(model, pretrained_path):
    print('Loading pretrained model from {}'.format(pretrained_path))
    pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)    
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
    
output_onnx = '../weights/output.onnx'
raw_weights = '../weights/model.pth'

# load weight
net = you_net()
net = load_model(net, raw_weights)
net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)

input_names = ["input0"]
output_names = ["output0"]
inputs = torch.randn(1, 3, 300, 300).to(device)
torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)

安装onnx简化工具

pip3 install onnx-simplifier onnxruntime

简化onnx模型

这一步一定要做,否则后面转onnx的时候会报错

python3 -m onnxsim model.onnx model_sim.onnx

三、编译NCNN框架

主要参考ncnn官网的教程即可,windows下编译同上一篇的MNN的编译都差不多,只有一点需要说明,官网的教程上有vulkan-sdk的安装然后打开-DNCNN_VULKAN=ON编译选项。我一切照做后编译出来的ncnn.lib在运行ncnn::Extractor ex = Net->create_extractor();这个函数后的所有操作之后,返回的时候就报堆栈溢出错误,包括加载官网给出的例子全部报错;后来不cmake的时候这个编译选项不打开编译出来的ncnn.lib就一切正常了。可能是自己的问题,也没去深究。反正能用就OK了。我把编译出来的ncnn.lib ncnn.a和linux下的onnx2ncnn工具都放在了我的网盘里,不想被编译折磨的就直接去下吧。如果编译遇见问题,也可以给我留言,哈哈~
说明:ncnnd.lib是windows下的debug版本,ncnn.lib是release版本,libncnn.a是linux下的库文件,onnx2ncnn是linux下的转换工具。
下载地址:NCNN 提取码:6cuc

四、C++调用和封装

说明

对于vs中lib库和include目录的配置就不赘述了,有不懂的之前的文章有提过,假定工程已经配置完成。大体的调用过程NCNN和MNN都差不多,先加载模型创建一个指向模型的指针,然后创建session、创建用于处理输入的tensor,将input_tensor送入session,运行session,最后得到网络的输出。如果对C++比较熟悉的话,看着官网的教程比葫芦画瓢即可,只有一个地方需要说明就是对输出的获得。先看下我的代码和官网的代码再说为什么
我的输出:

	// run net and get output
	ncnn::Mat out, out1;
	ret = ex.extract("output0", out);
	ex.extract("376", out1);

官网的例子输出:

    ncnn::Mat out;
    ex.extract("detection_out", out);

辣么问题来了,我的"output0"和"376"、官网的“detection_out”都哪里来的?有两个地方可以得到,最简单的方法,使用MNN框架下的转换工具,在转换完成的时候会给出模型的输入和输出名称,直接拷贝即可

>MNNConvert.exe -f ONNX --modelFile model.onnx --MNNModel slime.mnn  --bizCode biz

MNNConverter Version: 0.2.1.5git - MNN @ 2018


Start to Convert Other Model Format To MNN Model...
[17:49:58] :29: ONNX Model ir version: 6
Start to Optimize the MNN Net...
[17:49:58] :20: Inputs: input0
[17:49:58] :37: Outputs: output0, Type = Concat
[17:49:58] :37: Outputs: 376, Type = Softmax
Converted Done!

如果没有MNN的转换工具,在后面加载模型后单步跟一下,在Net = new ncnn::Net()变量中有个blob变量,在内存中查看一下,里面存的有模型的各个层的名称。代码中的img_w,img_h就是在第二步转换的时候你指定的w,h。这里只写了核心调用函数,具体使用时还请自行添加一些辅助函数!

C++代码

detection.h

#pragma once

#include 
#include 
#include 
#include "net.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include "omp.h"


struct bbox {
	float x1;
	float y1;
	float x2;
	float y2;
	float s;
};

struct box {
	float cx;
	float cy;
	float sx;
	float sy;
};

struct ObjectInfo {
	float x1; //bbox的left
	float y1; //bbox的top
	float x2; //bbox的right
	float y2; //bbox的bottom
	float prob; //置信度
};


class ObjectDetection
{
private:
	float _nms = 0.4;
	float _threshold = 0.6;
	const float mean_vals[3] = { 104.f, 117.f, 123.f };
	const float norm_vals[3] = { 1.0 / 104.0, 1.0 / 117.0, 1.0 / 123.0 };
	cv::Mat img;
	ncnn::Net *Net;
	int img_w = 300;
	int img_h = 300;
	int numThread;
	int detect_count = 0;
	static inline bool cmp(bbox a, bbox b);
public:
	ObjectDetection(std::string modelFolder, int num_thread);
	~ObjectDetection();
	int Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj);
};

detection.cpp

#include "Detection.h"
#include 

ObjectDetection::ObjectDetection(std::string modelFolder, int num_thread)
{
	Net = new ncnn::Net();
	std::string model_param = modelFolder + "Detect.param";
	std::string model_bin = modelFolder + "Detect.bin";
	int ret = Net->load_param(model_param.c_str());
	ret = Net->load_model(model_bin.c_str());	
	numThread = num_thread;
}

ObjectDetection::~ObjectDetection()
{
	if (Net != nullptr)
	{
		delete Net;
		Net = nullptr;
	}
}

int ObjectDetection::Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj)
{
	int ret = -1;	
	ncnn::Mat in = ncnn::Mat::from_pixels_resize(inputImage, ncnn::Mat::PIXEL_BGR, inputw, inputh, img_w, img_h);
	in.substract_mean_normalize(mean_vals, norm_vals);	
	ncnn::Extractor ex = Net->create_extractor();		
	ex.set_light_mode(true);
	ret = ex.input("input0", in);	
	// run net and get output
	ncnn::Mat out, out1;
	// bbox的输出
	ret = ex.extract("output0", out);
	ex.extract("376", out1);

	// get result	
	for (int i = 0; i < out.h; ++i)
	{
		// 得到网络的具体输出
		const float *boxes = out.row(i);
		const float *scores = out1.row(i);
		// 执行你自己的操作
	}
	std::sort(total_box.begin(), total_box.end(), cmp);
	NMS(total_box, _nms);
	return 0;
}

五、 编写JNI C++

在Android Studio中配置NDK,具体配置网上有很多教程我就不啰嗦了,假定android strdio的jni c++环境已经配置完成。源码中的函数名的格式是jni c++要求的,必须这种格式,根据实际情况修改,函数名中的"com_example_demokit_Detection"对应到java的应用中就是"com.example.demokit.Detection"这样就很好理解了。
native-lib.cpp

#include 
#include 

#include "Detection.h"
#include 

extern "C" JNIEXPORT jlong JNICALL
Java_com_example_demokit_Detection_Create(JNIEnv *env, jobject instance, jstring path) {
    char* _path;
    _path = (char*)env->GetStringUTFChars(path,0);
    Detection *phandle = new Detection(_path, 2);
    return (jlong)phandle;
}

extern "C" JNIEXPORT jintArray JNICALL
Java_com_example_demokit_Detection_Detect(JNIEnv *env, jobject instance, jlong handle, jint campos, jint w, jint h, jbyteArray data_) {
    Detection *gp = NULL;
    if (handle)
        gp = (Detection *)handle;
    else
        return nullptr;
    jbyte *data = env->GetByteArrayElements(data_, NULL);

    std::vector<ObjectInfo> objects;
    gp->Detect((unsigned char*)data, w, h, objects);
    env->ReleaseByteArrayElements(data_, data, 0);
    jintArray jarr = env->NewIntArray(objects.size()*15+1);
    jint *arr = env->GetIntArrayElements(jarr, NULL);
    arr[0] = objects.size();
    for (int i = 0; i < objects.size(); i++)
    {
        arr[i*5 + 1] = objects[i].x1;
        arr[i*5 + 2] = objects[i].y1;
        arr[i*5 + 3] = objects[i].x2;
        arr[i*5 + 4] = objects[i].y2;
        arr[i*5 + 5] = objects[i].prob;
    }
    env->ReleaseIntArrayElements(jarr, arr, 0);
    return jarr;
}

六、java调用

package com.example.demokit;

public class Detection {
    static {
        System.loadLibrary("native-lib");
    }
    private long handle;
    public Detection(String path){
        handle = Create(path);
    }
    public int[] Detect(int w, int h, byte[] data){
            return  Detect(handle, w, h, data);
    }
    private native long Create(String path);
    private native int[] Detect(long handle, int w, int h, byte[] data);
}

七、 应用层使用

在应用层就可以直接调用上面的java类啦,搞定~

package com.example.demokit;
import android.graphics.Point;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DetectTool {
    private Detection mDetection;
    private static final int DATA_LENGTH = 5; // 矩形框坐标2个,每个具有x,y两个值;置信度1个;

    public DetectTool(String dect_model_dir){
        /**
         * @dect_model_dir: 检测模型所在的目录路径
         */
        mDetection = new Detection(dect_model_dir);
    }

    private ObjectInfo ArrayAnalysis(int[] src_array){
        /**
         * 对输入的数组进行解析,返回ObjectInfo对象
         * @src_array: 具有DATA_LENGTH所示结构的数组
         */
        ObjectInfo obj_info = new ObjectInfo();

        Point[] pointFaceBox = new Point[2];
        // face_bbox 坐标
        for(int i = 0; i < 2; i++) {
            Point point = new Point();
            point.x = src_array[2*i];
            point.y = src_array[2*i+1];
            pointFaceBox[i] = point;
        }
        // 置信度
        obj_info.setProb(src_array[4]);
        return obj_info ;
    }

    public List<ObjectInfo> GetObjectInfo(int width, int height, byte[] data){
        /**
         * @width:图片宽度
         * @height:图片高度
         * @data:图片的字节流
         */
        int[] obj= mDetection.Detect(width, height, data);
        List<ObjectInfo> obj_list = new ArrayList<>();
        int obj_count = obj[0];
        for(int i = 0; i < obj_count ; i++){
            int[] obj_array = Arrays.copyOfRange(obj, i*DATA_LENGTH + 1, (i + 1) * DATA_LENGTH+1);
            ObjectInfo obj_info = this.ArrayAnalysis(obj_array);
            obj_list.add(obj_info);
        }
        return obj_list;
    }
}

你可能感兴趣的:(深度学习,c++,ncnn,jnic++,onnx,深度学习)