之前的文章 已经介绍了 matchnet, Siamese network,以及获取他们数据集,此外 我们也改了网络的模型。现在我们离全部工作还差一步之遥,不能放弃。
应该有很多人想要用自己的数据集去训练模型,所以这一篇博文的目的就是 连接数据和网络,制作适合你的数据集。
---------------------------------------------------------------------------------------------------------------
首先,我选择的数据依然是 brown的数据,这个数据集使用比较广泛,想要下载的同学,请查阅第一讲。
def main():
#input arg
args=ParseArgs();
#read 3Dpoint IDs
with open(args.info_file) as f:
point_id=[int(line.split()[0]) for line in f] #对数据切片,去第一个数据
with open(args.interest_file)as f: interest=[[float(x)for x in line.split()] for line in f]
db=leveldb.LevelDB(args.out_db, create_if_missing=True, error_if_exists=True)
#add patches to database
batch=leveldb.WriteBatch()
total=len(interest)
processed=0
for i,metadata in enumerate(interest) #这样 i是interest文件的行数,metadata对应该行的 元数组
datum=caffe_pb2_Datum()
datum.channels, datum.height, datum.width=(2,64,64)
#extract the patch
datum.data=GetPatchImage(i, args.container_dir).tostring())
datum.label=point_id[i]
datum.float_data.extend(metadata)
batch.Put(str(i),datum.serializeToString())
processed+=1;
db.write(batch, sysc=True)
上面的Python 代码 可以看到 datum.data 和datum.label,并且它的label就是point_id,也就是3D点的ID。
def GetPatchImage(patch_id,container_dir):
#deal with dmp
PATCHES_PER_IMAGE=16*16
PATCHES_PER_ROW=16
PATCH_SIZE=64
container_idx,container_offerset=divmod(patch_id,PATCHES_PER_IMAGE)
row_idex,col_ied=divmod(container_ofset,PATCHES_PER_ROW)
#extract the patch from the iamge
patch_image= GetPatchIamge.cached.container_img[\PATCH_SIZE *row_dix:PATCH_SIZE*(row_idx+1),\PATCHE_SIZE*col_idx: PATCH_SIZE * (col_idx+1)]
return patch_image
可以看出 patch_image就是从cached.container_img中按顺序取出一个64*64的patch,这个函数调用是在主函数的循环语句中。