Torch 和 C++互相调用 pybind11

Torch 和 C++互相调用 pybind11

  • torch 和 C++互相调用
  • 代码例子

torch 和 C++互相调用

只需安装 torch 即可，已在 Linux 环境下测试通过。
torch.utils.cpp_extension 借助 pybind11 实现 C++ 与 Python 之间的相互调用。
它使用 ninja 构建系统进行即时编译（JIT）：只有第一次运行时需要编译 C++ 代码，之后直接复用已编译的结果（源文件改动后才会重新编译）。

代码例子

下面的例子展示了 Python 与 C++ 后端如何相互调用，以及如何在两者之间传递列表和 torch Tensor。

CPP.cpp

// torch/extension.h comes first: it includes Python.h (via pybind11), which
// must precede standard headers. It also provides torch::Tensor and pybind11.
#include <torch/extension.h>

#include <iostream>
#include <string>
#include <vector>


// Plain C++ class that is exposed to Python as `Pet`.
// Binding reference: https://pybind11.readthedocs.io/en/latest/advanced/classes.html
struct Pet {
    std::string name;  // public so C++ code can also read/write it directly

    /// Creates a pet called `initialName`.
    Pet(const std::string &initialName) : name{initialName} {}

    /// Replaces the pet's name with `newName`.
    void setName(const std::string &newName) { name = newName; }

    /// Returns the current name; the reference is valid while this Pet lives.
    const std::string &getName() const { return name; }
};

// Vector of pets shared between C++ and Python (bound as the `PetList` class).
using PetList = std::vector<Pet>;

// The functions below call into Python and hand it a C++ PetList by reference.
// Declaring std::vector<Pet> opaque stops pybind11 from converting it to a
// Python list copy. It is exposed as the bound `PetList` class instead, so
// changes made in Python are visible in C++. Per the pybind11 docs, this must
// appear at global scope before any binding code that uses the type.
// https://pybind11.readthedocs.io/en/latest/advanced/cast/stl.html?highlight=STL#making-opaque-types
PYBIND11_MAKE_OPAQUE(std::vector<Pet>)

/// Demonstrates passing a C++ container to Python by reference.
///
/// Builds a PetList holding "CatCpp" and calls the Python function
/// `PY.addPet` with a pointer to it. It then prints every pet name, including
/// any pets that Python appended to the same vector.
///
/// @throws py::error_already_set if importing `PY` or the Python call fails.
void addAndprintPet()
{
    PetList petlist;
    petlist.emplace_back("CatCpp");

    py::object addPet = py::module::import("PY").attr("addPet");

    // Pass a pointer, not the object. pybind11 casts call arguments with
    // return_value_policy::automatic_reference, which for a pointer means
    // "reference". Python therefore mutates `petlist` itself, not a copy.
    addPet(&petlist);

    // Iterate by const reference: the previous `auto pet` copied every Pet
    // (and its std::string) on each iteration.
    for (const auto &pet : petlist)
    {
        // std::endl (with its flush) is kept deliberately, because this output
        // is mixed with Python's print().
        std::cout << "from CPP " << pet.getName() << std::endl;
    }
}

//通过返回值传递列表注意一切python返回皆为object,需要强转
//https://pybind11.readthedocs.io/en/latest/advanced/pycpp/object.html#instantiating-compound-python-types-from-c
void printList()
{
    py::list a;
    a.append(123); //python 有的基本都能用,包括模块
    // py::module sys=py::module::import("sys");
    // py::print(sys.attr("path"));
    py::object addNumer=py::module::import("PY").attr("addNumer");

    py::list b = addNumer(a);
    
    for (auto number:b)
    {
        std::cout<<"from CPP "<<number.cast<int>()<<std::endl;
    }
}

/// Returns the element-wise sum of `a` and `b` as a new tensor.
torch::Tensor TensorAdd(const torch::Tensor &a, const torch::Tensor &b)
{
    torch::Tensor sum = a.add(b);
    return sum;
}

/// Entry point exposed to Python. Runs both C++ -> Python demos in order.
void mainFun()
{
    addAndprintPet();  // Python appends a Pet to a C++ vector in place
    printList();       // Python returns a new, extended list
}

// Python bindings for this extension: PetList, Pet, mainFun and TensorAdd.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Register Pet first, so PetList's method signatures/docstrings show the
    // Python name `Pet` rather than the raw C++ type name.
    py::class_<Pet>(m, "Pet")
        .def(py::init<const std::string &>())
        .def("setName", &Pet::setName)
        .def("getName", &Pet::getName)
        // Used by Python's print()/repr(): shows just the pet's name.
        .def("__repr__", [](const Pet &pet) { return pet.getName(); });

    py::class_<PetList>(m, "PetList")
        .def(py::init<>())
        // std::vector::pop_back on an empty vector is undefined behaviour.
        // Guard it and raise IndexError in Python instead.
        .def("pop_back", [](PetList &v) {
            if (v.empty()) {
                throw py::index_error("pop_back from empty PetList");
            }
            v.pop_back();
        })
        // Use a lambda instead of casting &PetList::push_back. Taking the
        // address of standard-library member functions is not portable.
        .def("push_back", [](PetList &v, const Pet &pet) { v.push_back(pet); })
        .def("__len__", [](const PetList &v) { return v.size(); })
        // keep_alive<0, 1>: the returned iterator keeps the list alive.
        .def("__iter__", [](PetList &v) {
            return py::make_iterator(v.begin(), v.end());
        }, py::keep_alive<0, 1>());

    m.def("mainFun", &mainFun, "mainFun");
    m.def("TensorAdd", &TensorAdd, "TensorAdd");
}

PY.py

import os
import torch
from torch.utils.cpp_extension import load

# Directory containing this script; CPP.cpp is expected to sit next to it.
# Named _this_dir rather than `dir` to avoid shadowing the built-in dir().
_this_dir = os.path.dirname(os.path.realpath(__file__))

# JIT-compile CPP.cpp into the extension module "CPP" and import it. The build
# is cached, so normally only the first run has to compile the C++ code.
CPP = load(
  name="CPP",
  sources=[os.path.join(_this_dir, "CPP.cpp")],
  verbose=False)


def addPet(petlist):
  """Print every pet in ``petlist``, then append a new Pet named 'CatPy'.

  Called from C++ (addAndprintPet) with a PetList passed by reference, so the
  appended pet is visible to the C++ caller afterwards.
  ``petlist.pop_back()`` could be used here to remove the last pet instead.
  """
  for pet in petlist:  # petlist is the bound PetList type
    print('from PY', pet)
  new_pet = CPP.Pet('CatPy')
  petlist.push_back(new_pet)


def addNumer(numlist):
  """Print ``numlist`` and return a new list with 1234 appended.

  The argument is not modified; ``numlist + [1234]`` builds a fresh list.
  """
  print('from PY', numlist)
  extended = numlist + [1234]
  return extended

if __name__ == '__main__':
  # Call into C++. CPP.mainFun runs addAndprintPet and printList, which call
  # back into this file's addPet and addNumer.
  CPP.mainFun()

  # Use the C++-defined Pet class from Python.
  p = CPP.Pet("Molly")
  print(p)  # printed via Pet.__repr__, i.e. the pet's name
  print(p.getName())
  p.setName("Charly")
  print(p.getName())

  # Element-wise addition of two 3x3 tensors, done in C++.
  print(CPP.TensorAdd(torch.zeros((3, 3)), torch.ones((3, 3))))


Torch 和 C++互相调用 pybind11_第1张图片

你可能感兴趣的:(python,c++,混合编程,pytorch)