A hands-on guide to PyTorch → ONNX → TensorRT: (1) creating a custom TensorRT layer

Basics

  • For C++, TensorRT exposes custom layers through pure virtual interfaces: implement the virtual functions and TensorRT can run your own layer. Note the pairings: nvinfer1::IPluginV2 works together with nvinfer1::IPluginCreator, and nvinfer1::IPlugin with nvinfer1::IPluginFactory. One pair is essentially a re-wrapping of the other; IPluginV2 is basically IPlugin with a few extra functions, and the real difference is which TensorRT versions support them. Either base class is a valid choice.
  • The official IPluginV2 documentation lives in two places: the descriptions of the functions themselves, and the documentation of their parameters.
  • There are two main sets of examples: the samples installed alongside TensorRT (on my machine, /usr/src/tensorrt/samples), and the plugins shipped with onnx-tensorRT. Not every version of onnx-tensorRT contains every plugin, so match the version number (the IN, i.e. InstanceNormalization, layer is one example); here is the link.
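
A quick orientation before the demo: a custom IPluginV2 layer is inserted into a network like any built-in layer. Here is a minimal sketch (mine, not from the official samples; `network`, `input`, and `plugin` are assumed to already exist):

// Sketch: attaching a custom layer during network construction.
// Assumptions: `network` is an nvinfer1::INetworkDefinition*, `input` is an
// nvinfer1::ITensor*, and `plugin` is an instance of an IPluginV2 subclass.
nvinfer1::IPluginV2Layer *layer = network->addPluginV2(&input, 1, plugin);
layer->setName("group_norm_1");  // optional, helps when debugging the engine build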

Group normalization example

The resources above are enough to build your own layer. Below is a small demo of mine to illustrate some of the finer points. It follows the standard pattern; the points that deserve attention are listed separately:

  • 1. The constructor receives the parameters supplied when the layer is created; keep it distinct from the enqueue function below, which receives the runtime inputs.
  • 2. getWorkspaceSize is the temporary GPU memory the layer uses while it runs. You do not allocate it yourself, but you must request it here; how it is consumed shows up in my ".../cuda/groupnorm.cu" (its assumed interface is sketched just before the listing).
  • 3. serialize is matched with the deserializing constructor; the payload includes a struct's worth of fields (C, HxW, G, epsilon), so the write order and the read order must not differ.
  • 4. REGISTER_TENSORRT_PLUGIN() registers the creator with TensorRT, so the plugin can be found by its name, presumably through a dictionary; the onnx-tensorRT parsing discussed later is also a dictionary. A lookup sketch follows this list.
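
To make point 4 concrete, here is a minimal lookup sketch (my illustration; the creator class itself appears at the end of the listing):

// After REGISTER_TENSORRT_PLUGIN has run, the creator is found by (name, version),
// much like a dictionary lookup.
nvinfer1::IPluginCreator *creator =
        getPluginRegistry()->getPluginCreator("group_norm", "1");
assert(creator != nullptr);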

The demo also shows a subclass of IPluginFactory (the commented-out GNPluginFactory in the listing), but it is entirely the standard pattern, so there is nothing more to say about it.

  • One more note: the two IPluginCreator functions, const nvinfer1::PluginFieldCollection *getFieldNames() and nvinfer1::IPluginV2 *createPlugin(), are not strictly necessary to implement; I will come back to them in a dedicated section on how to test this plugin layer. Still, the implementations here are worth a look, since they took some time to get right.
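
The listing below includes "../cuda/groupnorm.h", which this post does not show. Judging from the call sites, its interface must look roughly like the following (my reconstruction; the exact return and parameter types are assumptions):

// Reconstructed declarations, inferred from how the demo calls them (assumption):
size_t calWorkSpace(int maxBatchSize, int C, int G);  // scratch bytes for the per-group statistics
int RunOnDeviceWithOrderNCHW(int N, int C, int HxW, int G,
                             const float *X, const float *gamma, const float *beta, float *Y,
                             float epsilon, void *workspace, size_t workspaceSize,
                             cudaStream_t stream);  // launches the GroupNorm kernels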
#ifndef INFER__GN_H
#define INFER__GN_H

#include <cassert>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include "../cuda/groupnorm.h"

#define GN_PLUGIN_NAME "group_norm"
#define GN_PLUGIN_VERSION "1"
#define GN_PLUGIN_NAMESPACE ""
#define CHECK_CU(status) { if ((status) != 0) throw std::runtime_error(std::string{__FILE__} + ":" + std::to_string(__LINE__) + " CUDA Error: " + std::to_string(status)); }


// Helpers to move data to/from the GPU.
inline nvinfer1::Weights copyToDevice(const void *hostData, int count) {
    void *deviceData;
    CHECK_CU(cudaMalloc(&deviceData, count * sizeof(float)));
    // cudaMemcpyDefault lets the runtime infer the copy direction, so this helper also
    // works when clone() passes device-resident weights (requires unified addressing).
    CHECK_CU(cudaMemcpy(deviceData, hostData, count * sizeof(float), cudaMemcpyDefault));
    return nvinfer1::Weights{nvinfer1::DataType::kFLOAT, deviceData, count};
}

inline int copyFromDevice(char *hostBuffer, nvinfer1::Weights deviceWeights) {
    // Writes [int count][count floats] into hostBuffer and returns the bytes written.
    *reinterpret_cast<int *>(hostBuffer) = static_cast<int>(deviceWeights.count);
    CHECK_CU(cudaMemcpy(hostBuffer + sizeof(int), deviceWeights.values, deviceWeights.count * sizeof(float),
                        cudaMemcpyDeviceToHost));
    return sizeof(int) + deviceWeights.count * sizeof(float);
}

template<typename T>
void write(char *&buffer, const T &val) {
    *reinterpret_cast<T *>(buffer) = val;
    buffer += sizeof(T);
}

template<typename T>
void read(const char *&buffer, T &val) {
    val = *reinterpret_cast<const T *>(buffer);
    buffer += sizeof(T);
}

inline void checkTensorData(int N, const void *inputs, const char *message) {
    // Debug helper: copy the first N floats of a device buffer to the host and print them.
    const float *B = reinterpret_cast<const float *>(inputs);
    std::vector<float> b(N);
    CHECK_CU(cudaMemcpy(b.data(), B, N * sizeof(float), cudaMemcpyDeviceToHost));
    std::cout << message << " in " << __FILE__ << "@" << __LINE__ << " :";
    for (int i = 0; i < N; i++) std::cout << b[i] << ',';
    std::cout << std::endl;
}

class GNPlugin : public nvinfer1::IPluginV2 {

public:
    // Build-time constructor: weights[0] holds the per-channel scale (gamma) and
    // weights[1] the per-channel bias (beta); both are copied to the GPU here.
    GNPlugin(const nvinfer1::Weights *weights, int nbWeights, int group, float epsilon) {
        assert(nbWeights == 2);
        mKernelWeights = copyToDevice(weights[0].values, weights[0].count);
        mBiasWeights = copyToDevice(weights[1].values, weights[1].count);
        G = group;
        epsilon_ = epsilon;
    }

    GNPlugin() = delete;

    // Create the plugin at runtime from a byte stream.
    GNPlugin(const void *data, size_t length) {
        const char *d = reinterpret_cast<const char *>(data);
        const char *start = d;
        // Deserialize kernel.
        read(d, C);
        read(d, HxW);
        read(d, G);
        read(d, epsilon_);
        const int kernelCount = reinterpret_cast<const int *>(d)[0];
        mKernelWeights = copyToDevice(d + sizeof(int), kernelCount);
        d += sizeof(int) + mKernelWeights.count * sizeof(float);
        // Deserialize bias.
        const int biasCount = reinterpret_cast<const int *>(d)[0];
        mBiasWeights = copyToDevice(d + sizeof(int), biasCount);
        d += sizeof(int) + mBiasWeights.count * sizeof(float);
        // Check that deserialization consumed exactly `length` bytes.
        assert(d == start + length);
    }

    virtual int getNbOutputs() const override { return 1; }

    virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
                                               int nbInputDims) override {
        // Note: with the implicit-batch API, the Dims here are CHW; the batch dimension is not included.
        assert(index == 0 && nbInputDims == 1);
        return *inputs;
    }

    virtual int initialize() override { return 0; }

    virtual void terminate() override {}

    virtual size_t getWorkspaceSize(int maxBatchSize) const override {
        // Scratch memory for the per-group statistics; TensorRT allocates it and
        // passes it to enqueue() as `workspace`. Recomputed on every call, since
        // C and G can differ between plugin instances.
        return calWorkSpace(maxBatchSize, C, G);
    }

    // Runtime entry point: `inputs` and `outputs` are device pointers supplied by
    // TensorRT, and `workspace` is the scratch buffer requested above.
    virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
                        void *workspace, cudaStream_t stream) override {
        return RunOnDeviceWithOrderNCHW(batchSize, C, HxW, G,
                                        reinterpret_cast<const float *>(inputs[0]),
                                        reinterpret_cast<const float *>(mKernelWeights.values),
                                        reinterpret_cast<const float *>(mBiasWeights.values),
                                        reinterpret_cast<float *>(outputs[0]),
                                        epsilon_,
                                        workspace,
                                        getWorkspaceSize(batchSize),
                                        stream);
    }

    // For this sample, we'll only support float32 with NCHW.
    virtual bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) const override {
        return (type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kNCHW);
    }

    void configureWithFormat(const nvinfer1::Dims *inputDims, int nbInputs,
                             const nvinfer1::Dims *outputDims, int nbOutputs,
                             nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) {
        assert(type == nvinfer1::DataType::kFLOAT);
        assert(format == nvinfer1::PluginFormat::kNCHW);
        assert(mKernelWeights.count == inputDims[0].d[0] && mBiasWeights.count == inputDims[0].d[0]);
        assert(nbOutputs == 1 && outputDims[0].d[0] == inputDims[0].d[0]);
        C = inputDims[0].d[0];
        HxW = inputDims[0].d[1] * inputDims[0].d[2];
    }

    size_t getSerializationSize() const override {
        // 3 ints (C, HxW, G) + 1 float (epsilon_) + 2 x (int count + C floats) for gamma and beta.
        return sizeof(int) * (2 + 3) + sizeof(float) + 2 * C * sizeof(float);
    }

    void serialize(void *buffer) const override {
        char *d = static_cast<char *>(buffer);
        const char *start = d;
        write(d, C);
        write(d, HxW);
        write(d, G);
        write(d, epsilon_);
        d += copyFromDevice(d, mKernelWeights);
        d += copyFromDevice(d, mBiasWeights);
        assert(d == start + getSerializationSize());
    }

    // Free buffers.
    void destroy() override {
        cudaFree(const_cast<void *>(mKernelWeights.values));
        mKernelWeights.values = nullptr;
        cudaFree(const_cast<void *>(mBiasWeights.values));
        mBiasWeights.values = nullptr;
    }

    const char *getPluginType() const override {
        return GN_PLUGIN_NAME;
    }

    const char *getPluginVersion() const override {
        return GN_PLUGIN_VERSION;
    }

    const char *getPluginNamespace() const override {
        return GN_PLUGIN_NAMESPACE;
    }

    void setPluginNamespace(const char *N) override {}

    IPluginV2 *clone() const override {
        const int nbWeights = 2;
        const nvinfer1::Weights weights[nbWeights] = {mKernelWeights, mBiasWeights};
        // The weights already live on the GPU; copyToDevice handles that via
        // cudaMemcpyDefault. Carry the configured shape over to the copy as well.
        auto *plugin = new GNPlugin(weights, nbWeights, G, epsilon_);
        plugin->C = C;
        plugin->HxW = HxW;
        return plugin;
    }

private:
    int C, HxW, G;
    float epsilon_;
    nvinfer1::Weights mKernelWeights{nvinfer1::DataType::kFLOAT, nullptr},
            mBiasWeights{nvinfer1::DataType::kFLOAT, nullptr};

};

//class GNPluginFactory : public nvinfer1::IPluginFactory {
//public:
//    bool isPlugin(const char *name) override {
//        printf("gn factory: %s", name);
//        return isPluginExt(name);
//    }
//
//    bool isPluginExt(const char *name) override {
//        printf("gn factory: %s", name);
//        return !strcmp(name, GN_PLUGIN_NAME);
//    }
//
//    // Create a plugin using provided weights.
//    virtual nvinfer1::IPlugin *
//    createPlugin(const char *layerName, const nvinfer1::Weights *weights, int nbWeights) override {
//        static const int GROUP = 32; static const float EPS = 1e-5f;
//        assert(isPluginExt(layerName) && nbWeights == 2);
//        assert(mPlugin.get() == nullptr);
//        mPlugin = std::unique_ptr<GNPlugin>(new GNPlugin(weights, nbWeights, GROUP, EPS));
//        return mPlugin.get();
//    }
//
//    // Create a plugin from serialized data.
//    virtual nvinfer1::IPlugin *
//    createPlugin(const char *layerName, const void *serialData, size_t serialLength) override {
//        assert(isPlugin(layerName));
//        // This will be automatically destroyed when the engine is destroyed.
//        return new GNPlugin{serialData, serialLength};
//    }
//
//    // User application destroys plugin when it is safe to do so.
//    // Should be done after consumers of plugin (like ICudaEngine) are destroyed.
//    void destroyPlugin() { mPlugin.reset(); }
//
//    std::unique_ptr<GNPlugin> mPlugin{nullptr};
//};

class GNPluginCreator : public nvinfer1::IPluginCreator {
public:
    GNPluginCreator() {
        // Describe GNPlugin's required PluginField arguments. The vector is a member
        // (not a local) so that mFC.fields does not dangle once the constructor returns.
        mAttributes.emplace_back(nvinfer1::PluginField("count", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
        mAttributes.emplace_back(nvinfer1::PluginField("num_groups", nullptr, nvinfer1::PluginFieldType::kINT32, 1));
        mAttributes.emplace_back(nvinfer1::PluginField("eps", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
        mAttributes.emplace_back(nvinfer1::PluginField("w", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));
        mAttributes.emplace_back(nvinfer1::PluginField("b", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1));

        // Fill PluginFieldCollection with the PluginField metadata.
        mFC.nbFields = static_cast<int>(mAttributes.size());
        mFC.fields = mAttributes.data();
    }

    const char *getPluginName() const override {
        return GN_PLUGIN_NAME;
    }

    const char *getPluginVersion() const override {
        return GN_PLUGIN_VERSION;
    }

    const char *getPluginNamespace() const override {
        return GN_PLUGIN_NAMESPACE;
    }

    nvinfer1::IPluginV2 *
    deserializePlugin(const char *name, const void *serialData, size_t serialLength) override {
        printf("name is %s\n", name);
        return new GNPlugin(serialData, serialLength);
    }

    void setPluginNamespace(const char *N) override {}

    const nvinfer1::PluginFieldCollection *getFieldNames() override {
        return &mFC;
    }

    nvinfer1::IPluginV2 *createPlugin(const char *name,
                                      const nvinfer1::PluginFieldCollection *fc) override {
        // Zero-initialize so a missing field is detected instead of read as garbage.
        int count = 0, group = 0;
        float eps = 0.f;
        const float *kernel = nullptr, *bias = nullptr;
        const nvinfer1::PluginField *fields = fc->fields;

        // Parse fields from PluginFieldCollection
        assert(fc->nbFields == 5);
        for (int i = 0; i < fc->nbFields; i++) {
            if (strcmp(fields[i].name, "count") == 0) {
                assert(fields[i].type == nvinfer1::PluginFieldType::kINT32);
                count = *(reinterpret_cast<const int *>(fields[i].data));
            } else if (strcmp(fields[i].name, "num_groups") == 0) {
                assert(fields[i].type == nvinfer1::PluginFieldType::kINT32);
                group = *(static_cast<const int *>(fields[i].data));
            } else if (strcmp(fields[i].name, "eps") == 0) {
                assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
                eps = *(static_cast<const float *>(fields[i].data));
            } else if (strcmp(fields[i].name, "w") == 0) {
                assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
                kernel = static_cast<const float *>(fields[i].data);
            } else if (strcmp(fields[i].name, "b") == 0) {
                assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
                bias = static_cast<const float *>(fields[i].data);
            }
        }

        nvinfer1::Weights weights[] = {nvinfer1::Weights{nvinfer1::DataType::kFLOAT, kernel, count},
                                       nvinfer1::Weights{nvinfer1::DataType::kFLOAT, bias, count}};
        return new GNPlugin(weights, 2, group, eps);
    }

private:
    nvinfer1::PluginFieldCollection mFC;
    std::vector<nvinfer1::PluginField> mAttributes;
};

REGISTER_TENSORRT_PLUGIN(GNPluginCreator);
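// Note: REGISTER_TENSORRT_PLUGIN expands to a static registrar object, so simply
// linking a translation unit that includes this header makes ("group_norm", "1")
// discoverable through getPluginRegistry(); no explicit registration call is needed.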

#undef  GN_PLUGIN_NAME
#undef  GN_PLUGIN_VERSION
#undef  GN_PLUGIN_NAMESPACE

#endif //INFER__GN_H
