使用工具:
https://github.com/xxradon/PytorchToCaffe
主要使用其中的 pytorch_to_caffe.py,
参照仓库 example 目录中的使用示例,
将训练好的 mobilev2_fcn 模型成功转换成了 caffemodel:
# Checkpoint (state_dict) of the trained FCN-MobileNetV2 lane segmentation model.
CHECKPOINT_PATH = "/home/pc007/PycharmProjects/Lane_seg_anngic/params/FCN_mobilev2_end17_s=2_equal_anngic_lane_params_init.pkl"


def convert_to_caffe(name='Anngic_lanenet', checkpoint_path=CHECKPOINT_PATH,
                     input_shape=(1, 3, 224, 640), num_classes=11):
    """Convert the trained PyTorch lane network to a Caffe model.

    Traces the network with PytorchToCaffe's ``pytorch_to_caffe`` module and
    writes ``<name>.prototxt`` and ``<name>.caffemodel`` to the current
    working directory.

    Args:
        name: Base name for the generated files (also passed to ``trans_net``).
        checkpoint_path: Path to the saved ``state_dict`` of the network.
        input_shape: Shape of the dummy tensor traced through the network.
            The default (1, 3, 224, 640) matches the 640x224 images fed to
            the converted model at inference time.
        num_classes: Number of output classes passed to ``fcn_mobile``.

    Note:
        ``fcn_mobile`` and ``pytorch_to_caffe`` are not defined in this
        snippet; they must be imported / in scope before calling this.
    """
    import torch

    net = fcn_mobile(num_classes)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the network and dummy input below live on the CPU anyway.
    # torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(checkpoint)
    # Inference mode (affects layers such as dropout/batch-norm, if present).
    net.eval()
    # The values are irrelevant; the tensor only drives the layer tracing.
    dummy_input = torch.ones(input_shape)
    pytorch_to_caffe.trans_net(net, dummy_input, name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))


if __name__ == '__main__':
    convert_to_caffe()
Caffe 模型可视化:参见“caffemodel 可视化”。
通过 Caffe 的 Python 接口测试 caffemodel,验证转换后的模型无偏差:
import caffe
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
import torch.nn.functional as F
import cv2
"""anngic labelme 标注颜色"""
lane_color = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128],
[128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0]])
def color_seg(seg, palette):
    """Colourise a label image with a palette.

    Args:
        seg: H x W array of integer class IDs.
        palette: Array whose row ``k`` is the colour for class ID ``k``
            (rows are expected to hold 3 channel values).

    Returns:
        H x W x 3 array of colours, with the same dtype as ``palette``.
    """
    # Look up every pixel's colour in one fancy-indexing pass over the
    # flattened labels, then restore the spatial layout plus a channel axis.
    flat_ids = seg.ravel()
    pixel_colours = palette[flat_ids]
    return pixel_colours.reshape((*seg.shape, 3))
# Test the converted Caffe model: run one road image through it, save the
# predicted class IDs and display the colourised lane segmentation.
#
# NOTE(review): this only inspects the Caffe output. To confirm the conversion
# introduced no deviation, compare against the PyTorch model's argmax on the
# same input tensor.


def _preprocess_image(image_path, crop_box, input_size):
    """Load, crop and resize an image into a (1, 3, H, W) float32 array.

    ``crop_box`` is PIL's (left, upper, right, lower). The default crop
    (40, 300, 1240, 720) is 1200x420, which has the same aspect ratio as the
    640x224 network input, so the resize does not distort the image.
    The input size must match training: feeding the full (1280, 720) frame
    changes the field of view and gives larger errors on near-range dashed
    lane lines (still to be investigated).
    NOTE(review): pixels are only scaled to [0, 1] by ToTensor, with no mean/std
    normalisation -- assumes training used the same preprocessing; confirm.
    """
    # The context manager closes the file handle; crop() forces the pixel
    # data to load before the file is closed.
    with Image.open(image_path) as input_img:
        img = input_img.convert("RGB").crop(crop_box)
    img = img.resize(input_size)
    return transforms.ToTensor()(img).unsqueeze(dim=0).float().numpy()


def _predict_classes(net, im_input, input_blob, output_blob):
    """Run ``net`` on ``im_input`` and return the per-pixel argmax as uint8."""
    blob = net.blobs[input_blob]
    # The blob shape is fixed by the prototxt; fail with a clear message
    # instead of a broadcasting error (or a silent broadcast).
    if blob.data.shape != im_input.shape:
        raise ValueError("input shape {} does not match blob '{}' shape {}".format(
            im_input.shape, input_blob, blob.data.shape))
    blob.data[...] = im_input
    net.forward()
    scores = net.blobs[output_blob].data
    # Argmax over axis 1 (class scores); uint8 holds IDs of up to 256 classes.
    return np.argmax(scores, axis=1).squeeze().astype(np.uint8)


def _show_prediction(pred_index, palette):
    """Colourise ``pred_index`` with ``palette`` and show it in an OpenCV window."""
    # cv2.cvtColor rejects the default int64 output, hence the uint8 cast.
    np_pred = np.uint8(color_seg(pred_index, palette))
    # The palette is RGB while OpenCV displays BGR.
    np_pred_cv = cv2.cvtColor(np_pred, cv2.COLOR_RGB2BGR)
    cv2.imshow("np_pred_cv", np_pred_cv)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def run_caffe_test(deploy_file="Anngic_lanenet.prototxt",
                   caffemodel_file="Anngic_lanenet.caffemodel",
                   image_path="/home/pc007/Desktop/00001.jpg",
                   crop_box=(40, 300, 1240, 720),
                   input_size=(640, 224),
                   input_blob='blob1',
                   output_blob='conv_transpose_blob3',
                   output_txt="output.txt",
                   palette=None,
                   show=True):
    """Run one image through the converted Caffe model.

    Args:
        deploy_file: Caffe network definition (.prototxt).
        caffemodel_file: Caffe weights (.caffemodel).
        image_path: Image to segment.
        crop_box: PIL (left, upper, right, lower) region to keep.
        input_size: (width, height) the cropped image is resized to.
        input_blob: Name of the network's input blob.
        output_blob: Name of the blob holding the class scores.
        output_txt: File the predicted class-ID map is written to.
        palette: Class-ID -> colour table; defaults to ``lane_color``.
        show: Whether to display the colourised prediction.

    Returns:
        The H x W uint8 array of predicted class IDs.
    """
    if palette is None:
        palette = lane_color
    net = caffe.Net(deploy_file, caffemodel_file, caffe.TEST)
    im_input = _preprocess_image(image_path, crop_box, input_size)
    print(im_input.shape)
    pred_index = _predict_classes(net, im_input, input_blob, output_blob)
    print(pred_index.shape)  # (224, 640) for the default input_size
    # fmt="%d": class IDs are integers; the default fmt writes "%.18e" floats.
    np.savetxt(output_txt, pred_index, fmt="%d")
    if show:
        _show_prediction(pred_index, palette)
    return pred_index


if __name__ == '__main__':
    run_caffe_test()
学习 Caffe 的优质博客:
https://blog.csdn.net/langb2014/article/details/53081911
https://blog.csdn.net/Artprog/article/details/79276536
https://blog.csdn.net/warrentdrew/article/details/103496480?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.edu_weight&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.edu_weight