C++调用python训练的pytorch模型(三)----- 实战:封装pytorch模型

文章目录

        • 封装python 模型 SDK
          • 准备好python api函数
          • C++调用python api
          • 生成so文件
        • 调用模型SDK
          • demo
          • makefile
          • 执行demo

封装python 模型 SDK

准备好python api函数

python代码

# webcam_test.py
# Module-level handle to the loaded detector. A bare `global` statement at
# module scope does not create the name, so it is bound explicitly here;
# load_model() replaces it with the real model object.
g_model = None
def load_model(wkspace_dir, cfg_file):
    """Create the detector and publish it through the module-level ``g_model``.

    The function echoes both arguments, switches the process working
    directory to ``wkspace_dir``, merges ``cfg_file`` into the shared ``cfg``
    object, freezes that object, and wraps it in a ``COCODemo``.

    Args:
        wkspace_dir: directory passed to ``os.chdir`` before the config is
            merged.
        cfg_file: config path passed to ``cfg.merge_from_file``. The merge
            happens after the chdir, so a relative path is presumably
            resolved against ``wkspace_dir``.
    """
    global g_model
    print("wkspace_dir: %s" % wkspace_dir)
    print("cfg_file: %s" % cfg_file)

    # The working directory is changed first; the config merge follows it.
    os.chdir(wkspace_dir)
    cfg.merge_from_file(cfg_file)
    cfg.freeze()

    demo_options = dict(
        confidence_threshold=0.7,
        show_mask_heatmaps=False,
        masks_per_dim=2,
        min_image_size=480,
    )
    g_model = COCODemo(cfg, **demo_options)

def forward(image):
    """Run detection on one image file and describe the first kept detection.

    Args:
        image: path of the image file; it is read with ``cv2.imread``.

    Returns:
        ``"<class_index>,<score>,<box>"`` for the first detection returned by
        ``select_top_predictions``, where ``score`` is formatted with two
        decimals and ``<box>`` is every element of the box formatted with two
        decimals and joined by ``", "``. Returns ``""`` when there is no
        detection.

    Raises:
        RuntimeError: if no model has been loaded into ``g_model`` yet.
        IOError: if ``cv2.imread`` could not read ``image``.
    """
    print('image path is: %s' %image)

    # Look the model up defensively so that a missing load_model() call is
    # reported clearly instead of surfacing as a NameError/AttributeError.
    model = globals().get("g_model")
    if model is None:
        raise RuntimeError("load_model() must be called before forward()")

    pixels = cv2.imread(image)
    if pixels is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise IOError("cannot read image: %s" % image)

    predictions = model.compute_prediction(pixels)
    predictions = model.select_top_predictions(predictions)
    scores = predictions.get_field("scores").tolist()
    labels = predictions.get_field("labels").tolist()
    labels = [model.CATEGORIES[i] for i in labels]
    boxes = predictions.bbox

    # Only the first detection is reported: the loop body always returns.
    for box, score, label in zip(boxes, scores, labels):
        if label is None:
            return ""
        label = model.CATEGORIES.index(label)
        bbox_str = ' '.join(str('%.2f,' % i) for i in box)
        bbox_str = bbox_str.rstrip(',')
        result = str(label) + ',' + str('%.2f' % score) + ',' + bbox_str
        print(result)
        return result
    return ""

if __name__ == "__main__":
    # main()
    # load_model() has two required parameters; calling it with none raises
    # TypeError. The values below are the ones used elsewhere in this article.
    load_model('/home/bob/wkspace/git/maskrcnn-benchmark/demo', 'r50_1204.yaml')
    forward('/home/bob/wkspace/git/maskrcnn-benchmark/demo/183101.jpg')
C++调用python api

hands_detect.h

#include 
#include 
#include 
#include 

using namespace std;

namespace handsDetect{
    /**
    *  功能描述:网络初始化,加载模型。
    *  @param wkspace_path     :python工作目录必须设置maskrcnn/demo目录;
    *  @param cfg_file  :配置文件的路径,可以相对demo的路径;
    *  @return  0-成功;-1-失败;
    */
int Initialize(const char* wkspace_path,const char* cfg_file);
	 /**
	    *  功能描述:执行网络前向预测。
	    *  @param image_file    :网络输入一张图片;
	    *  @param output        :网络输出,长度为6的数组:依次保持 [class,score,bbox[4]]
	    *  @return  0-成功;-1-失败;
	   */
int  Forward(const char* image_file, vector &output);
    /**
    *  功能描述:释放内存
    *  @return  0-成功;-1-失败;
    */
void Uninitialize();
}

hands_detect.cpp

#include "hands_detect.h"

// Python object handles shared by the functions in this file; all start empty.
// NOTE(review): presumably filled in by Initialize(), whose body is not
// visible here — confirm.
PyObject *g_pModule = nullptr;
PyObject *g_pFunc_init = nullptr;
// Callable that Forward() invokes through PyObject_CallObject.
PyObject *g_pFunc_forward = nullptr;

namespace handsDetect{
int Initialize(const char* wkspace_path,const char* cfg_file)
{
	if(wkspace_path == nullptr || cfg_file == nullptr){
		cout<<"please check wkspace path or cfg file path"<& v, const string& c)
{
    string::size_type pos1, pos2;
    pos2 = s.find(c);
    pos1 = 0;
    while(string::npos != pos2)
    {
        v.push_back(s.substr(pos1, pos2-pos1));

        pos1 = pos2 + c.size();
        pos2 = s.find(c, pos1);
    }
    if(pos1 != s.length())
        v.push_back(s.substr(pos1));
}

int string_to_data(char* src_str, vector &output)
{
    //char *src_str= "1.1,2.2,3.3,4.44,5.5555,6.666";
    string s = src_str;
    vector v;
    split_string(s, v,","); //可按多个字符来分隔;
    for(vector::size_type i = 0; i != v.size(); ++i)
    {
        //cout <<"v:"<< v[i] << endl;
//        cout <<"o:"<< atof(v[i].c_str()) << endl;
    	output.push_back(atof(v[i].c_str()));
    }

    return 0;
}
int  Forward(const char* image_file, vector &output)
{
	PyObject *args = Py_BuildValue("(s)",image_file);
	PyObject *pRet = PyObject_CallObject(g_pFunc_forward,args);
	if(pRet == nullptr)
	{
		cout<<"Error: python pFunc_forward pRet is null!"<
生成so文件

bash命令行

# Build the shared library and link it against the Python 3.6 runtime.
g++ hands_detect.cpp  -fPIC -shared -o libhandsdetect.so -std=c++11 \
-I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/  \
-L/home/bob/anaconda2/envs/benchmark_py36/lib/ -lpython3.6m

makefile

# source object target
SRCS   := hands_detect.cpp
OBJS   := hands_detect.o
TARGET := libhandsdetect.so

# compile and lib parameter
CC      := g++
CFLAGS  := -Wall -g
LIBS    := -lpython3.6m
LIB_DIR := -L/home/bob/anaconda2/envs/benchmark_py36/lib/
DEFINES :=
INCLUDE := -I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/

# all and clean are actions, not files: keep them working even if a file
# with one of those names appears in this directory
.PHONY: all clean

all: $(TARGET)

# build the shared library directly from the sources; it is rebuilt only
# when a source file or the header it includes has changed
$(TARGET): $(SRCS) hands_detect.h
	$(CC) $(INCLUDE) $(CFLAGS) -fPIC -shared -std=c++11 $(SRCS) -o $(TARGET) $(LIB_DIR) $(LIBS)

clean:
	rm -fr *.o $(TARGET)

调用模型SDK

demo
#include "hands_detect.h"
#include 
using namespace std;
namespace hd=handsDetect;
int main()
{
	vector output;
	if(hd::Initialize("/home/bob/wkspace/git/maskrcnn-benchmark/demo/","r50_1204.yaml") < 0){
		cout<< "hands detect initialize is failed"<::size_type i = 0; i != output.size(); ++i)
    	{
        	cout <
makefile
# source object target
SOURCE := demo.cpp
OBJS   := demo.o
TARGET := demo

# compile and lib parameter
CC      := g++
LIBS    := -lhandsdetect
LDFLAGS := -L. \
           -L./lib \
           -L/home/bob/anaconda2/envs/benchmark_py36/lib/
DEFINES :=
INCLUDE := -I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/ \
           -I./lib
# run-time search paths for the shared libraries, passed to the linker
CFLAGS  := -Wl,-rpath=/home/bob/anaconda2/envs/benchmark_py36/lib/ -Wl,-rpath=./lib
CXXFLAGS:=

# clean is an action, not a file: keep it working even if a file named
# "clean" appears in this directory
.PHONY: clean

# link
$(TARGET):$(OBJS)
	$(CC) -o $@ $^ $(LDFLAGS) $(CFLAGS)  $(LIBS)

# compile
$(OBJS):$(SOURCE)
	$(CC) $(INCLUDE) -g -c $^ -o $@

# all:
# $(CC) -o $(TARGET) $(SOURCE)

clean:
	rm -fr *.o $(TARGET)
执行demo
cd lib
make clean
make
cd ..
make clean
make
./demo

你可能感兴趣的:(C/C++,机器学习,linux)