PyTorch Model to Caffe (Part 1): Parsing caffemodel and prototxt

Table of Contents

  • PyTorch Model to Caffe (Part 1)
    • 1. Introduction to Caffe
    • 2. Object Detection with Caffe
    • 3. The Five Core Components of Caffe
    • 4. caffemodel
    • 5. Recovering train.prototxt from the caffemodel
    • 6. Remaining Issues with caffemodel Parsing

PyTorch Model to Caffe (Part 1)

1. Introduction to Caffe

(Figure 1: Caffe overview)

2. Object Detection with Caffe

  • Object detection is done with SSD; the main steps are shown below (the key part is porting the model)
    (Figure 2: main steps of the SSD detection pipeline)

3. The Five Core Components of Caffe

(Figure 3: the five core components of Caffe)

4. caffemodel

  • A caffemodel contains the prototxt (everything except solver.prototxt) plus the trained weights and biases;
    the prototxt part stores the network structure as plain text (a short sketch for reading the weights back follows the snippet below)
  • Create a caffe_pb2.NetParameter() object and parse the caffemodel into it to access its contents
import caffe.proto.caffe_pb2 as caffe_pb2

model = caffe_pb2.NetParameter()
with open(caffemodel_filename, 'rb') as f:
    model.ParseFromString(f.read())
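Since the trained weights and biases live in each layer's blobs field, a minimal sketch for reading them back as arrays might look like this (assuming a recent caffemodel that records blob shapes in the shape field; older models use num/channels/height/width instead):

import numpy as np

for layer in model.layer:
    if len(layer.blobs) == 0:
        continue  # layers without trainable parameters carry no blobs
    # blobs[0] holds the weights; blobs[1] holds the bias when bias_term is true
    weights = np.array(layer.blobs[0].data, dtype=np.float32).reshape(list(layer.blobs[0].shape.dim))
    print(layer.name, layer.type, 'weights:', weights.shape)
    if len(layer.blobs) > 1:
        bias = np.array(layer.blobs[1].data, dtype=np.float32)
        print(layer.name, layer.type, 'bias:', bias.shape)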
  • Loop over every layer to pull out its parameters
    model.layer holds the per-layer information (a generic helper for the repeated indentation is sketched after the code)
## Rebuild the prototxt content field by field (a bit verbose)
for i, layer in enumerate(model.layer):
    tops = layer.top
    bottoms = layer.bottom
    top_str = ''
    bottom_str =''
    transform_param_str = ''
    data_param_str    = ''
    annotated_data_param_str=''
    for top in layer.top:
        top_str += '\ttop:"{}"\n'.format(top)
    for bottom in layer.bottom:
        bottom_str += '\tbottom:"{}"\n'.format(bottom)
    # transform 
    if str(layer.transform_param)!='':
        transform_param_str = str(layer.transform_param).split('\n')
        new_str_trans =''
        for item in transform_param_str:
            new_str_trans += '\t\t'+str(item) + '\n' if item!='' else ''
        # print(new_str_trans)
        transform_param_str = '\t' +'transform_param {\n'+ new_str_trans +'\t}'+'\n'
    # data_param
    if str(layer.data_param) != '':
        data_param_str = str(layer.data_param).split('\n')
        new_str_data_param =''
        for item in data_param_str:
            new_str_data_param += '\t\t'+str(item) + '\n' if item!='' else ''

        data_param_str = '\t' +'data_param {\n'+ new_str_data_param +'\t}'+'\n'
    # annotated_data_param
    if str(layer.annotated_data_param) != '':
        annotated_data_param_str = str(layer.annotated_data_param).split('\n')
        new_str_annotated_data_param =''
        for item in annotated_data_param_str:
            new_str_annotated_data_param += '\t\t'+str(item) + '\n' if item!='' else ''
        annotated_data_param_str = '\t' +'annotated_data_param {\n'+ new_str_annotated_data_param +'\t}'+'\n'
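The same indent-and-wrap pattern repeats for every parameter message, so it can be factored into a small helper (hypothetical, not part of the original script); the convolution_param block in the sample below could be produced the same way:

def format_param_block(name, msg):
    # Render a sub-message as "\tname {\n\t\t...\n\t}\n", mirroring the blocks above
    text = str(msg)
    if text == '':
        return ''
    body = ''
    for line in text.split('\n'):
        body += '\t\t' + line + '\n' if line != '' else ''
    return '\t' + name + ' {\n' + body + '\t}' + '\n'

# e.g. inside the loop over model.layer:
# convolution_param_str = format_param_block('convolution_param', layer.convolution_param)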
  • Part of the parsed output
### Convolution layer from train.prototxt
layer {
  name: "conv1_2"
  type: "Convolution"
  bottom: "conv1_1"
  top: "conv1_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 64
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}

5. Recovering train.prototxt from the caffemodel

  • The goal is to understand how data is laid out inside a caffemodel
    The approach is elimination: dump every layer first, then strip out the blobs and any other fields that are not needed
import caffe.proto.caffe_pb2 as caffe_pb2
caffemodel_filename = src_path + '/***.caffemodel'
Tarpa_model = caffe_pb2.NetParameter()        
f = open(caffemodel_filename, 'rb')
Tarpa_model.ParseFromString(f.read())
f.close()

print(Tarpa_model.name)
print(Tarpa_model.input)
# print(Tarpa_model.layer)
# print(type(Tarpa_model.layer))
f = open('_caffemodel_.log','w')
f.write('name: "{}"'.format(Tarpa_model.name)+'\n')
for i, layer in enumerate(Tarpa_model.layer):
    layer_lines = str(layer).split('\n')
    new_str_trans = ''
    skip_close_flag = 0
    for item in layer_lines:
        # drop the phase marker and everything belonging to the stored weight blobs
        if item == 'phase: TRAIN':
            continue
        if skip_close_flag and '}' in item:
            continue
        skip_close_flag = 0
        if 'blobs' in item or 'data:' in item or 'shape' in item or 'dim:' in item:
            skip_close_flag = 1
            continue
        new_str_trans += '\t' + str(item) + '\n' if item != '' else ''
    layer_str = 'layer {\n' + new_str_trans + '}\n'
    f.write(layer_str)
    print(i)
    # if i ==2:
    #     break
f.close()
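A less hand-rolled alternative (a sketch built on the same caffe_pb2 API, not what the script above does) is to clear the blobs fields directly on the protobuf message and let google.protobuf.text_format serialize what is left:

from google.protobuf import text_format
import caffe.proto.caffe_pb2 as caffe_pb2

net = caffe_pb2.NetParameter()
with open(caffemodel_filename, 'rb') as f:
    net.ParseFromString(f.read())

# drop the trained weights so only the architecture remains
for layer in net.layer:
    del layer.blobs[:]

with open('recovered_train.prototxt', 'w') as f:
    f.write(text_format.MessageToString(net))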

6. Remaining Issues with caffemodel Parsing

The generated .prototxt contains many Split layers, which has not been cleaned up yet. These are most likely the Split layers Caffe inserts automatically when one blob feeds several consumers; they are stored in the serialized network even though the handwritten train.prototxt never declares them (a possible workaround is sketched after the example below).

layer {
	name: "data_data_0_split"
	type: "Split"
	bottom: "data"
	top: "data_data_0_split_0"
	top: "data_data_0_split_1"
	top: "data_data_0_split_2"
	top: "data_data_0_split_3"
	top: "data_data_0_split_4"
	top: "data_data_0_split_5"
	top: "data_data_0_split_6"
	top: "data_data_0_split_7"
}
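One possible workaround (a hypothetical sketch, not verified against the original model): skip the auto-inserted Split layers and map their renamed outputs back to the source blob before writing the prototxt.

clean = caffe_pb2.NetParameter()
clean.name = net.name
rename = {}  # e.g. "data_data_0_split_0" -> "data"

for layer in net.layer:
    if layer.type == 'Split':
        # every output of a Split layer is just an alias of its single input
        for top in layer.top:
            rename[top] = layer.bottom[0]
        continue
    new_layer = clean.layer.add()
    new_layer.CopyFrom(layer)
    for idx, bottom in enumerate(new_layer.bottom):
        if bottom in rename:
            new_layer.bottom[idx] = rename[bottom]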
