pytorch 模型转到c++ torch模型 CenterNet为例

https://pytorch.apachecn.org/docs/1.2/advanced/cpp_export.html

前期:我们拿到的centernet有带DCN版本的,但是CenterNet源码自带编译的,所以先摒弃它,还没尝试如何转到torch里面,使用dlav032模型先。

1使用下面的方法,注意目前只支持较新版本的pytorch,至少0.4是不可以的;

以下代码我是直接写在/src/lib/detectors/base_detector.py  类初始化里面,


self.model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 512, 512).cuda()
res_model = torchvision.models.resnet18()
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(self.model,example)
# save model
traced_script_module.save("./model.pt")
    

另外还需要使得模型是单输出单输出,如果输出矩阵宽高相同forward里面concat操作即可,不同的话,我还没想到什么办法,比如下面这个centernet的修改。

    def forward(self, x):
        x = self.base(x)
        x = self.dla_up(x[self.first_level:])
        #x = self.fc(x)
        #y = self.softmax(self.up(x))
        temp_heads = []
        for head in self.heads:
            temp_head = self.__getattr__(head)(x)
            temp_heads.append(temp_head)
        
        re = torch.cat(temp_heads, 1)
        return re

 

c++工程的验证

    int img_size = 512;
    std::shared_ptr simaRpnNet;
    std::string det_model_path = "/data_1/vir/car_detection/train_ws/CenterNet/c++/model.pt";
    torch::NoGradGuard no_grad;
    simaRpnNet = torch::jit::load(det_model_path);
    simaRpnNet->to(at::kCUDA);
    assert(simaRpnNet != nullptr);
    cout << "[INFO] init model done...\n";

    vector inputs;  //def an input

    cv::Mat src, image, float_image;
    src = cv::imread("/data_1/vir/car_detection/train_ws/CenterNet/images/00d13b4a-0306-4d08-a740-b1b5c63f94c40.jpg");
    cout<(0, 0)[0])<(0, 0)[1])<(0, 0)[2])<(0, 0)[0])<(0, 0)[1])<(0, 0)[2])< rgb
    image.convertTo(float_image, CV_32F, 1.0 / 255);   //归一化到[0,1]区间 TODO
    float *point_img;
    cout<(0, 0)[0]<(0, 0)[1]<(0, 0)[2]<forward(inputs).toTensor();  //前向传播获取结果
    inputs.pop_back();
    cout<<"Forward over!!!"<

核对是否正确的时候碰到的几个问题:

1. c++里面输入的时候不做bgr->rgb python工程里面是做了的。

2. 比较后输出矩阵值完全一样

你可能感兴趣的:(深度学习,环境配置)