生成非图像类型的LMDB数据

最近在训练网络中会用到非图像类型的数据,我这里是将这种数据转换成LMDB类型作为一个数据层,加载进网络。主要用到caffe的Python接口。
1、在网络的中间层中,其接受一个1x6维的bottom数据作为输入;
2、每个训练样本对应的1x6维的数据存储到data.txt,同时记录其类别标签;
3、写入LMDB 。

#-*- coding: UTF-8 -*-
import numpy as np 
import caffe
import lmdb
from caffe.proto import caffe_pb2   
import sys,os  


# 读入数据和对应的类别标签
theta_file=open('./data.txt','r')
label=open('./label.txt','r')
theta_list=[]
theta_label=[]
for line in theta_file:
    content=line.strip().split(',')
    theta=[]
    for i in range(len(content)):
        theta.append(float(content[i]))  
    theta_list.append(theta)
    del content,theta
theta_file.close() 

for line in label:
    content=line.strip().split('\n')
    theta_label.append(int(content[0]))

# 写入lmdb,需要将list转换为array
db = lmdb.open('data_lmdb', map_size=int(1e12))
with db.begin(write=True) as in_txn:
    for i in range(len(theta_list)):
        datum = caffe.proto.caffe_pb2.Datum()  
        datum.channels = 1  
        datum.height = 1  
        datum.width = 6
        tmp_=theta_list[i]
        tmp=np.array(range(6), dtype=np.float)
        for j in range(6):
            tmp[j]=tmp_[j]
        label=int(theta_label[i])
        datum.data = tmp.tobytes()
        # datum.data = tmp.tostring() 
        datum.label=label
        in_txn.put('{:0>10d}'.format(i), datum.SerializeToString()) 
db.close()

你可能感兴趣的:(深度学习,caffe)