PyTorch 模型转换为 Caffe 模型（caffemodel）

使用工具：
https://github.com/xxradon/PytorchToCaffe

主要使用其中的 pytorch_to_caffe.py，
参照仓库 example 目录中的示例，
将训练好的 mobilev2_fcn 模型成功转换成了 caffemodel。转换代码如下（其中 `fcn_mobile` 是自己项目中的网络定义，需要自行导入）：


if __name__ == '__main__':
    # Convert the trained PyTorch lane-segmentation network into a Caffe
    # .prototxt / .caffemodel pair with PytorchToCaffe's pytorch_to_caffe.py.
    import torch
    import pytorch_to_caffe

    # NOTE(review): `fcn_mobile` is the project's own network definition and is
    # not defined in this snippet -- import it from the training code.
    name = 'Anngic_lanenet'
    # Presumably the argument is the number of classes (the Caffe test palette
    # has 11 colours) -- confirm against fcn_mobile's signature.
    net = fcn_mobile(11)

    # map_location='cpu': the conversion below runs on CPU tensors, so a
    # checkpoint saved from a GPU run must not require CUDA to load.
    checkpoint = torch.load(
        "/home/pc007/PycharmProjects/Lane_seg_anngic/params/FCN_mobilev2_end17_s=2_equal_anngic_lane_params_init.pkl",
        map_location='cpu',
    )
    net.load_state_dict(checkpoint)
    net.eval()  # inference mode (BatchNorm uses running stats, dropout off) before tracing

    # Dummy input with the deployment shape (N, C, H, W). The Caffe test script
    # feeds a (1, 3, 224, 640) array into the exported input blob, so keep the
    # two in sync.
    dummy_input = torch.ones([1, 3, 224, 640])
    pytorch_to_caffe.trans_net(net, dummy_input, name)
    pytorch_to_caffe.save_prototxt(f'{name}.prototxt')
    pytorch_to_caffe.save_caffemodel(f'{name}.caffemodel')

Caffe 模型可视化：参见 “caffemodel 可视化” 一文。
使用 caffemodel 的 Python 接口进行测试，验证转换后的模型输出无偏差（可与 PyTorch 模型对同一张图片的预测结果进行对比）：

import caffe
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
import torch.nn.functional as F
import cv2


"""anngic labelme 标注颜色"""
lane_color = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128],
                      [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0]])


def color_seg(seg, palette):
    """
    Replace class ids with their palette colours.

    Takes:
        seg: integer array of class ids (typically H x W).
        palette: array-like of shape (num_classes, C); row i is the colour of
            class i. C is 3 for RGB, but any channel count works.
    Gives:
        Array of shape seg.shape + (C,) with the palette's dtype.

    Raises IndexError if seg contains an id outside the palette.
    """
    palette = np.asarray(palette)
    # Integer fancy indexing maps every id to its palette row and keeps seg's
    # shape, so no flatten/reshape round-trip and no hard-coded channel count.
    return palette[np.asarray(seg)]


# test caffe model
if __name__ == '__main__':
    # Run the converted Caffe model on one image and display the predicted
    # lane segmentation, as a sanity check of the PyTorch -> Caffe conversion.

    # Load the network definition and the converted weights.
    deploy_file = "Anngic_lanenet.prototxt"  # Caffe network architecture
    caffemodel_file = "Anngic_lanenet.caffemodel"
    net = caffe.Net(deploy_file, caffemodel_file, caffe.TEST)

    # Preprocess: crop box (left, upper, right, lower) = (40, 300, 1240, 720),
    # then resize to the training input size (width 640, height 224).
    # NOTE: the full frames are 1280x720; the crop changes the field of view,
    # and near-range dashed lane markings show larger errors -- to be investigated.
    with Image.open("/home/pc007/Desktop/00001.jpg") as input_img:
        # convert('RGB') guarantees 3 channels (grayscale/RGBA/CMYK files would
        # otherwise not match the 3-channel input blob).
        img = input_img.convert('RGB').crop((40, 300, 1240, 720))
        img = img.resize((640, 224))

    # ToTensor: HWC uint8 [0, 255] -> CHW float32 [0, 1]; unsqueeze adds the
    # batch axis. NOTE(review): assumes training used ToTensor() without
    # mean/std normalisation -- confirm against the training transforms.
    im_input = transforms.ToTensor()(img).unsqueeze(dim=0).float().numpy()
    print(im_input.shape)

    # 'blob1' is the input blob name in the generated prototxt.
    net.blobs['blob1'].data[...] = im_input
    print(net.blobs['blob1'].data.shape)

    net.forward()

    # Output of the final transposed-conv layer; axis 1 holds the per-class
    # scores, so argmax over it gives the per-pixel class id. squeeze() drops
    # the batch axis.
    out = net.blobs['conv_transpose_blob3'].data
    pred_index = np.argmax(out, axis=1).squeeze().astype(np.uint8)
    print(pred_index.shape)  # expected (224, 640)

    # Save the class-id map as integers (the default fmt '%.18e' writes floats).
    np.savetxt("output.txt", pred_index, fmt='%d')

    # Colourise with the RGB palette. np.uint8 is required: cv2.cvtColor
    # rejects the palette's default integer dtype.
    np_pred = np.uint8(color_seg(pred_index, lane_color))
    np_pred_cv = cv2.cvtColor(np_pred, cv2.COLOR_RGB2BGR)  # OpenCV displays BGR
    cv2.imshow("np_pred_cv", np_pred_cv)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

推荐的 Caffe 学习博客：

https://blog.csdn.net/langb2014/article/details/53081911
https://blog.csdn.net/Artprog/article/details/79276536
https://blog.csdn.net/warrentdrew/article/details/103496480

你可能感兴趣的:(模型互转)