C++ 调用 python深度学习脚本 进行图像分类

C++ 调用 python深度学习脚本 进行图像分类

Python hello.py 文件(文件名需与 C++ 中 PyImport_ImportModule("hello") 导入的模块名一致)

import torch
import PIL
from PIL import Image
import numpy as np
import cv2
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms
class Hello:
    """Minimal demo class: shows a host program instantiating a Python class.

    The constructor argument is kept on the instance as ``a``.
    """

    def __init__(self, x):
        """Store ``x`` as ``self.a`` and echo it to stdout."""
        self.a = x
        print(self.a)

    def print(self, x=None):
        """Write ``x`` to stdout (prints ``None`` when no argument is given)."""
        value = x
        print(value)
def xprint():
    """Print the fixed greeting ``hello world`` to stdout and return ``None``."""
    greeting = "hello world"
    print(greeting)
def imshow(x):
    """Show a channel-interleaved image as greyscale and return its height.

    Args:
        x: 2-D array of shape ``(rows, 3 * cols)`` in which every row stores
            its pixels as consecutive 3-value groups. The groups are handed
            to OpenCV as B, G, R (``cv2.COLOR_BGR2GRAY``).

    Returns:
        The number of rows of the greyscale image.

    Raises:
        ValueError: if ``x`` is not 2-D or its width is not a multiple of 3.

    Side effects:
        Prints intermediate shapes, opens a window titled ``"test"`` and
        blocks in ``cv2.waitKey(0)`` until a key is pressed.
    """
    pixels = np.asarray(x)
    if pixels.ndim != 2 or pixels.shape[1] % 3 != 0:
        raise ValueError(
            "imshow expects a 2-D array of shape (rows, 3 * cols), got shape "
            f"{pixels.shape}"
        )

    rows, flat_width = pixels.shape
    # A reshape performs the de-interleave directly: element [i, j, k] is
    # x[i, 3 * j + k]. The previous stop index ``len(x[0] - 2)`` built a
    # temporary array just to take its length and never shortened anything.
    m = np.ascontiguousarray(pixels.reshape(rows, flat_width // 3, 3))

    # Same diagnostic output as before: the three per-channel plane shapes,
    # then the stacked image shape.
    plane_shape = m.shape[:2]
    print("-------")
    print(plane_shape, plane_shape, plane_shape)
    print(m.shape)

    gray = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
    print(gray.shape)
    cv2.imshow("test", gray)
    cv2.waitKey(0)  # 0 = wait indefinitely for a key press
    return gray.shape[0]
class Network:
    """VGG16-BN classifier with a 3-way output layer, used for inference only.

    Weights are read from ``pill3allkind.pth`` in the current working
    directory when the object is constructed.

    Attributes:
        model: the loaded network in evaluation mode (``None`` until
            ``initmodel`` has run).
        transform: preprocessing applied to a PIL image before inference.
    """

    def __init__(self):
        """Build the preprocessing pipeline and load the model weights."""
        print("ddd")
        self.model = None
        # Inference-time preprocessing must be deterministic, so the random
        # horizontal flip (a training-time augmentation) is not part of it.
        self.transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.initmodel()

    def initmodel(self):
        """Create the VGG16-BN architecture and load ``pill3allkind.pth``.

        The last classifier layer is replaced by a 3-output linear layer
        before the state dict is loaded, so the checkpoint must match that
        modified structure.
        """
        model_ft = models.vgg16_bn(pretrained=False)
        num_ftrs = model_ft.classifier[6].in_features
        # Replace the final layer so the network emits 3 scores.
        model_ft.classifier[6] = nn.Linear(num_ftrs, 3)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host; the model itself is never moved off the CPU in this class.
        state_dict = torch.load('pill3allkind.pth', map_location='cpu')
        model_ft.load_state_dict(state_dict)
        # Evaluation mode: BatchNorm uses its running statistics and Dropout
        # is disabled. Without this a single-image forward pass is unreliable.
        model_ft.eval()
        self.model = model_ft

    @staticmethod
    def _to_hwc_uint8(x):
        """Return ``x`` as a contiguous ``uint8`` array of shape (H, W, 3).

        Accepts either the flattened layout ``(H, 3 * W)`` with interleaved
        channel values, or an array that is already ``(H, W, 3)``.

        Raises:
            ValueError: for any other shape.
        """
        arr = np.asarray(x)
        if arr.ndim == 2 and arr.shape[1] % 3 == 0:
            # Element [i, j, k] becomes x[i, 3 * j + k].
            arr = arr.reshape(arr.shape[0], arr.shape[1] // 3, 3)
        elif not (arr.ndim == 3 and arr.shape[2] == 3):
            raise ValueError(
                "expected an array of shape (H, 3 * W) or (H, W, 3), got "
                f"{arr.shape}"
            )
        # PIL's fromarray needs uint8 here; other dtypes fail in the transform.
        return np.ascontiguousarray(arr, dtype=np.uint8)

    def _forward(self, tensor):
        """Run one preprocessed image tensor (C, H, W) through the model."""
        batch = tensor.unsqueeze(0)
        # No gradients are needed for inference.
        with torch.no_grad():
            return self.model(batch)

    def testImage(self, x):
        """Classify an image supplied as a NumPy array and print the scores.

        Args:
            x: image data, either ``(H, 3 * W)`` with interleaved channels or
                ``(H, W, 3)``.

        NOTE(review): the channel order is passed through unchanged. Data
        coming from ``cv::imread`` is BGR, while ``classlifyImage`` feeds RGB
        from PIL; confirm which order the checkpoint was trained with.
        """
        print("testImage")
        img = self._to_hwc_uint8(x)
        print(img.dtype)
        tensor = self.transform(Image.fromarray(img))
        print(tensor.shape)
        label = self._forward(tensor)
        print(label)

    def classlifyImage(self):
        """Classify ``content.jpg`` from the working directory and print the scores."""
        # The constructor already loads the weights; only load here if that
        # has not happened, instead of re-reading the checkpoint on every call.
        if self.model is None:
            self.initmodel()
        with Image.open("content.jpg") as source:
            # Normalize() is configured for 3 channels, so force RGB.
            img = source.convert("RGB")
        im = np.asarray(img)
        print(im.shape)
        print(im.dtype)
        tensor = self.transform(img)
        label = self._forward(tensor)
        print(label)
if __name__ == "__main__":
    # Smoke test: run one random image through the classifier.
    # (net.classlifyImage() classifies content.jpg from disk instead.)
    net = Network()
    height, width = 319, 510
    # testImage de-interleaves a 2-D buffer of shape (rows, cols * 3), the
    # layout the C++ caller sends. A 3-D (rows, cols, 3) array would be
    # sliced into a 4-D array and rejected by Image.fromarray.
    img = np.random.randint(0, 256, (height, width * 3), dtype=np.uint8)
    net.testImage(img)

该脚本包含两个类(Hello 和 Network)以及两个模块级函数(xprint 和 imshow),主要用于说明 C++ 如何调用 Python 的普通函数以及类的成员函数。

C++ Main文件

C++ main.cpp
// ConsoleApplication1.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
// NOTE(review): the header names in this listing were lost during extraction
// (each line read only "#include"). The list below is reconstructed from what
// the code actually uses; confirm it against the original project.
#include <Python.h>              // must precede the standard headers
#include <numpy/arrayobject.h>   // found under the installed NumPy package's include directory
#include <opencv2/opencv.hpp>
#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std;
int _tmain(int argc, _TCHAR* argv[])
{
Py_Initialize();
import_array();//如果要传送图片参数到python 脚本,务必加上该句,时numpy环境
// 将当前目录加入sys.path
PyRun_SimpleString("import sys");
PyRun_SimpleString("sys.path.append('./')");
// 导入hello.py模块
PyObject *pmodule = PyImport_ImportModule("hello");
// 获得函数xprint对象,并调用,输出“hello world\n”
PyObject *pfunc = PyObject_GetAttrString(pmodule, "xprint");
PyObject_CallFunction(pfunc, NULL);
// 获得类Hello并生成实例pinstance,并调用print成员函数,输出“5 6\n”
PyObject *pclass = PyObject_GetAttrString(pmodule, "Hello");
PyObject *arg = Py_BuildValue("(i)", 5);
PyObject *pinstance = PyObject_Call(pclass, arg, NULL); //实例化一个Hello类
PyObject_CallMethod(pinstance, "print", "i", 6);
cv::Mat img = cv::imread("content.jpg", cv::IMREAD_COLOR);
int m, n;
n = img.cols * 3;
m = img.rows;
unsigned char *data = (unsigned  char*)malloc(sizeof(unsigned char)* m * n);
int p = 0;
for (int i = 0; i < m; i++)
{
for (int j = 0; j < n; j++)
{
data[p] = img.at<unsigned char>(i, j);
p++;
}
}
npy_intp Dims[2] = { m, n }; //给定维度信息
PyObject*PyArray = PyArray_SimpleNewFromData(2, Dims, NPY_UBYTE, data);
PyObject*ArgArray = PyTuple_New(1);
PyObject*ArgArray1 = PyTuple_New(0);
PyTuple_SetItem(ArgArray, 0, PyArray);
PyObject *pDict = PyModule_GetDict(pmodule);
PyObject *pclass1 = PyObject_GetAttrString(pmodule, "Network");
PyObject* pinstance1 = PyObject_Call(pclass1, ArgArray1,NULL);//实例化一个Network类,该类没有参数传入,但要传一个空TupleArray
PyObject* result =PyObject_CallMethod(pinstance1, "testImage","O",ArgArray );
PyObject *pDict = PyModule_GetDict(pmodule);
PyObject*ArgArray2 = PyTuple_New(1);
PyTuple_SetItem(ArgArray2, 0, PyArray);
PyObject*pFuncFive = PyDict_GetItemString(pDict, "imshow");
    PyObject_CallObject(pFuncFive, ArgArray);
PyObject* pReturn = PyObject_CallObject(pFuncFive, ArgArray);
int result;
PyArg_Parse(pReturn, "i", &result);
//可以将result强制转换为结构体
//“i” 改为“s”
Py_DECREF(ArgArray);
Py_DECREF(pmodule);
//std::cout << "sss" << result;
Py_Finalize();
system("pause");
return 0;
}

上面的代码已经有比较详细的注释,这里不再赘述。

你可能感兴趣的:(编程记录)