pytorch添加C++拓展简单实战编写及基本功能测试

食用目录

    • 准备实验环境
    • 模块创建与使用基本流程
      • 第一步:编写C++内容
      • 第二步:编写setup.py
      • 第三步:命令行运行
      • 模块使用

准备实验环境

pytorch添加C++拓展简单实战编写及基本功能测试_第1张图片

模块创建与使用基本流程

  1. 利用C++写好自定义层发功能,主要包括前向传播和方向传播,以及pybind11的内容。
  2. 写好setup.py脚本, 并利用python提供的setuptools来编译并加载C++代码。
  3. 编译安装,在python中调用C++扩展接口
  4. 编程使用拓展

第一步:编写C++内容

test.h
注意这里调用了一个非常重要的文件
pytorch添加C++拓展简单实战编写及基本功能测试_第2张图片
test.cpp
pytorch添加C++拓展简单实战编写及基本功能测试_第3张图片
pytorch添加C++拓展简单实战编写及基本功能测试_第4张图片

第二步:编写setup.py

setup.py
pytorch添加C++拓展简单实战编写及基本功能测试_第5张图片
pytorch添加C++拓展简单实战编写及基本功能测试_第6张图片

第三步:命令行运行

在setup.py所在的文件夹下运行命令:
python setup.py install,
可以看到一堆输出,该 C++ 模块会被安装在 python 的 site-packages 中。

pytorch添加C++拓展简单实战编写及基本功能测试_第7张图片
稍稍等待后,就能在anaconda3/lib/site-packages文件路径下
发现名为test_cpp-0.1***.egg的文件夹,
在IDE的Project Interpreter中也可以找到对应的package
一般来说,在IDE中可以显示对应的自建包,就说明以上安装步骤没有出问题
pytorch添加C++拓展简单实战编写及基本功能测试_第8张图片

踩坑预警:
如果电脑上装有多个版本的python或者不同环境的python解释器,请确保测试代码运行的python环境和拓展包安装的python环境是同一个!
比如我电脑上有3.7(anaconda)和3.8两个版本的python。我把C++拓展写在了3.7版本的anaconda环境下,但是用3.8版本的python环境是无法调用自编C++拓展的。解决方法是在3.8版本的python库中重新install。

模块使用

test.py
pytorch添加C++拓展简单实战编写及基本功能测试_第9张图片
test2.py
pytorch添加C++拓展简单实战编写及基本功能测试_第10张图片
test2.py运行结果
pytorch添加C++拓展简单实战编写及基本功能测试_第11张图片

运行结果符合预期。说明我们在pytorch的extension.h的帮助,成功在C++层面上定义了Z=2*X+Y的简单网络层,并定义其前向传播与反向传播功能后,像调用torch库一样调用自己编写的C++拓展。

你可能感兴趣的:(知识笔记,踩坑记录)