分割数据转成lmdb格式,python代码

分割数据转成lmdb格式

分割数据转成lmdb格式通常有两种方法

1:原图和标签图各生成一个lmdb文件

2:原图和标签图写在一个lmdb文件中

这里采用方法2作为示例,文中以cityscapes数据集为转换对象,请读者根据自己文件夹的实际情况做修改:

import numpy as np
import sys
from PIL import Image
import lmdb
import random
import os
sys.path.append('~/caffe/python/')
import caffe

if __name__ == '__main__':
    k=0
    i=0
    in_db  = lmdb.open('~/cityscape_val_lmdb',map_size=int(1e12)) //存放lmdb文件的位置,为了避免路径问题,请引用的时候都改为绝对路径,下同
    in_txn = in_db.begin(write=True)  //打开写入句柄
    for dirs in  os.listdir('~/cityscape/gtFine/train')://遍历文件夹
        src_dir = os.walk('~/cityscape/gtFine/train/'+dirs)//遍历子文件夹下的文件,src_dir是一个生成器
        for path, dir, img_list in src_dir:
            random.shuffle(img_list)// 乱序子文件,cityscape中的文件都是按城市道路存储的,乱序对于以后的训练很重要
            # creating images lmdb
            for srcimg in img_list:
                if srcimg.endswith('_color.png')://找到原图
                    rgbimg = path + '/' + srcimg//原图的绝对路径
                    labimg = path + '/' + srcimg.split('color')[0] + 'labelIds.png'//找到对应的标签图
                    labname=srcimg.split('color')[0] + 'labelIds.png'//保存标签图的文件名,便于写入lmdb文件
                    if os.path.exists(labimg)://看看标签图是否存在,避免数据集缺损引起训练上的错误
                        rgb = np.array(Image.open(rgbimg))//开始读取原图,下面的操作为转换格式
                        Dtype = rgb.dtype
                        rgb= rgb[:,:,::-1]
                        rgb = Image.fromarray(rgb)
                        rgb = np.array(rgb, Dtype)
                        rgb = rgb.transpose((2, 0, 1))
                        rgb_dat = caffe.io.array_to_datum(rgb)
                        in_txn.put(srcimg, rgb_dat.SerializeToString())

                        lab = np.array(Image.open(labimg), Dtype)//开始读取标签图
                        lab = Image.fromarray(lab)
                        lab = np.array(lab, Dtype)
                        lab = lab.reshape(lab.shape[0],lab.shape[1],1)
                        lab = lab.transpose((2,0,1))
                        L_dat = caffe.io.array_to_datum(lab)  //将数据存储在datum中
                        in_txn.put(labname, L_dat.SerializeToString())//序列化数据,写入lmdb文件
                        i += 1
                        if i%100 == 0://100作为一个batchsize,提交数据到lmdb文件,初始化句柄,避免内存爆掉
                            k=k+1
                            in_txn.commit()
                            in_txn = in_db.begin(write=True)
                            print 'process %d batch' % k
    in_txn.commit()//写入最后一个batch到lmdb文件
    print 'process last batch!'
    in_db.close()
    print "finish!!!"

你可能感兴趣的:(data,processing,tool)