PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展

PyTorch学习总结(二)——基于torch.utils.ffi的自定义C扩展_第1张图片


步骤一 准备好你的C代码

首先,你写好你的C函数。

接下来你可以找到一个模块的forward和backward函数的实现,其主要实现输入相加的功能。

在你的.c文件中,你可以使用#include #include 指令来分别包含TH及THC。

ffi工具可以确保编译器在build的过程中找到它们。

/* src/my_lib.c */
#include 

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
    if (!THFloatTensor_isSameSizeAs(input1, input2))
        return 0;
    THFloatTensor_resizeAs(output, input1);
    THFloatTensor_cadd(output, input1, 1.0, input2);
    return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
    THFloatTensor_resizeAs(grad_input, grad_output);
    THFloatTensor_fill(grad_input, 1);
    return 1;
}

代码中没有约束条件。如果想要添加约束条件,你得准备一个header文件,它包含了所有希望在python中调用的函数的列表。

然后它会被ffi工具用来生成适当的封装。

/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

现在,你需要一个简短的文件来生成(build)你的自定义扩展:

# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib',
headers='src/my_lib.h',
sources=['src/my_lib.c'],
with_cuda=False
)
ffi.build()

步骤二 将它包含到你的Python代码中

当你运行完上述指令后,pytorch会创建一个_ext目录,然后将你的my_lib库放进去。

包名可以在最终的模块名前面有任意数量的包(数量也可以等于0).如果build成功,你可以像导入常规的python文件一样导入你的扩展。

定义新的函数:

# functions/add.py
import torch
from torch.autograd import Function
from _ext import my_lib


class MyAddFunction(Function):
    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_output, grad_input)
        return grad_input

定义新的模块:

# modules/add.py
from torch.nn import Module
from functions.add import MyAddFunction

class MyAddModule(Module):
    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)

在模块中实现嵌套:

# main.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from modules.add import MyAddModule

class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.add = MyAddModule()

    def forward(self, input1, input2):
        return self.add(input1, input2)

model = MyNetwork()
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
print(model(input1, input2))
print(input1 + input2)

你可能感兴趣的:(PyTorch)