python lmdb使用

python lmdb使用

    • python lmdb使用
      • 安装
      • 项目实际用例
      • PyTorch Dataset 数据加载方法重写

python lmdb使用

LMDB和SQLite/MySQL等关系型数据库不同,属于key-value数据库(把LMDB想成dict会比较容易理解)。在python lmdb中,键key与值value都是字节串(bytes),str需要先encode()再写入。

安装

pip install lmdb

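### 操作流程
1. 创建lmdb环境
env = lmdb.open(path, map_size=...)
2. 建立事务(写入、修改、删除需要 write=True)
txn = env.begin(write=True)
3. 向事务中写入或者修改数据(key、value 均为 bytes)
txn.put(key, value)
4. 删除数据
txn.delete(key)
5. 数据查询(key 不存在时返回 None)
txn.get(key)
6. 数据遍历
for key, value in txn.cursor(): ...
7. 数据提交
txn.commit()
8. 关闭环境
env.close()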

项目实际用例

在进行OCR文本识别训练时数据量较大,因此将数据保存为LMDB格式。转换时需要指定图片的根目录与标签文件。

如下所示:

# coding:utf-8
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import re
from PIL import Image
import numpy as np
import imghdr
import argparse
from tqdm import tqdm


def init_args():
    """Parse the command-line options of the LMDB conversion script.

    Options:
        -i/--image_dir   directory of the images (default '')
        -l/--label_file  file listing image paths and their labels
        -s/--save_dir    directory the LMDB files are written to
        -m/--map_size    LMDB map size (int)

    Returns:
        argparse.Namespace with attributes image_dir, label_file, save_dir, map_size.
    """
    option_table = [
        (('-i', '--image_dir'),
         {'default': '',
          'type': str,
          'help': 'The directory of the dataset , which contains the images'}),
        (('-l', '--label_file'),
         {'default': '/datassd/hzl/text_render_data/mingpian/new_gray.txt',
          'type': str,
          'help': 'The file which contains the paths and the labels of the data set'}),
        (('-s', '--save_dir'),
         {'default': '/datassd/hzl/text_render_data/mingpian/lmdb_gray/',
          'type': str,
          'help': 'The generated mdb file save dir'}),
        (('-m', '--map_size'),
         {'default': 40000000000000,
          'type': int,
          'help': 'map size of lmdb'}),
    ]
    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with non-zero area.

    Args:
        imageBin: encoded image bytes (the raw contents of an image file), or None.

    Returns:
        bool: False for None/empty input, data OpenCV cannot decode, or a
        zero-area image; True otherwise.
    """
    if imageBin is None or len(imageBin) == 0:
        return False
    try:
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        # Grayscale decoding is enough to validate the file; the pixels are discarded.
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    except (cv2.error, TypeError, ValueError):
        # Non-buffer input or an OpenCV decoding failure -> treat as invalid.
        return False
    # cv2.imdecode reports undecodable data by returning None instead of raising.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW > 0


def writeCache(env, cache):
    """Write every key/value pair of ``cache`` to ``env`` in a single write transaction.

    LMDB stores raw bytes, so ``str`` keys and values are encoded (UTF-8, the
    ``str.encode`` default) before writing; ``bytes`` are written unchanged.

    Args:
        env: an open LMDB environment (the object returned by ``lmdb.open``).
        cache: mapping whose keys and values are ``str`` or ``bytes``.
    """
    # One transaction per flush; py-lmdb commits it when the ``with`` block exits normally.
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if isinstance(k, str):
                k = k.encode()
            if isinstance(v, str):
                v = v.encode()
            txn.put(k, v)

def createDataset(outputPath, imagePathList, labelList, map_size, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Samples are stored under 1-based keys: the k-th written sample (k = 1..N) goes to
    ``image-%09d`` / ``label-%09d`` (and ``lexicon-%09d`` when ``lexiconList`` is given).
    ``num-samples`` holds N, the number of samples actually written. Images rejected
    by ``checkValid`` are skipped and not counted, so there are no gaps in the keys.
    The 1-based layout matches a reader that maps dataset index ``i`` to key ``i + 1``.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        map_size      : map size passed to lmdb.open
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=map_size)  # create the LMDB environment
    cache = {}
    cnt = 1  # key index of the next sample to be written (1-based)
    try:
        for i in tqdm(range(nSamples)):
            # Drop a trailing line ending left over from reading a label file.
            imagePath = imagePathList[i].rstrip('\r\n')
            label = labelList[i]
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label
            if lexiconList:
                # Index by input position i, not cnt: skipped images make cnt lag behind i.
                cache['lexicon-%09d' % cnt] = ' '.join(lexiconList[i])
            if cnt % 1000 == 0:
                # Flush in batches so the pending image bytes held in memory stay bounded.
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1

        nWritten = cnt - 1  # record how many samples were actually stored
        cache['num-samples'] = str(nWritten)
        writeCache(env, cache)
    finally:
        env.close()  # close the LMDB environment
    print('Created dataset with %d samples' % nWritten)


def load_label_file(label_file, image_dir=''):
    """Read a label file into parallel lists of image paths and labels.

    Each non-blank line is either ``<image path><TAB><label>`` or, when it contains
    no tab, ``<image path><whitespace><label>``. Only the first separator splits,
    so the label may itself contain spaces. Image paths are joined onto
    ``image_dir`` with ``os.path.join``: the default '' leaves them unchanged, and
    absolute paths in the file win. Lines without a label, and lines whose image
    file does not exist, are skipped.

    Returns:
        (imgPathList, labelList): two lists of equal length.
    """
    imgPathList = []
    labelList = []
    with open(label_file, mode='r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            # Prefer the tab separator; fall back to the first run of whitespace.
            sep = '\t' if '\t' in line else None
            fields = line.split(sep, 1)
            if len(fields) != 2:
                print('skip malformed line %d: %r' % (lineno, line))
                continue
            imgPath = os.path.join(image_dir, fields[0])
            if not os.path.exists(imgPath):
                continue
            imgPathList.append(imgPath)
            labelList.append(fields[1])
    return imgPathList, labelList


if __name__ == '__main__':
    args = init_args()  # parse command-line options
    # Read the label file, dropping malformed lines and missing images.
    imgPathList, labelList = load_label_file(args.label_file, args.image_dir)
    # Write the LMDB dataset.
    createDataset(args.save_dir, imgPathList, labelList, args.map_size)

PyTorch Dataset 数据加载方法重写

import lmdb
import six
import sys
from PIL import Image
import cv2
import numpy as np
from lib.dataset.transformers import *


class LMDB(Dataset):
    """Read-only dataset over an LMDB file of (image, label) records.

    Expects keys ``num-samples`` (the sample count N) and 1-based
    ``image-%09d`` / ``label-%09d`` for 1..N: dataset index ``i`` (0..N-1)
    is read from key ``i + 1``.
    """

    def __init__(self, root=None, is_train=True):
        # An explicit ``root`` is honoured; otherwise the path comes from config.LMDB_ROOT.
        # NOTE(review): ``is_train`` does not currently select the file. Confirm whether
        # config.DATASET.TRAIN_FILE / TEST_FILE should be used here instead.
        if root is None:
            root = config.LMDB_ROOT
        transform = config.DATASET.TRANSFORM
        target_transform = config.DATASET.TARGET_TRANSFORM
        # Opened read-only; this dataset never writes to the database.
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            map_size=40000000000000)

        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(1)  # non-zero: this is a failure, not a normal exit

        with self.env.begin(write=False) as txn:
            self.nSamples = int(txn.get('num-samples'.encode('utf-8')))

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label)`` for dataset index ``index`` (0-based).

        ``img`` is a PIL image, passed through ``self.transform`` if set. ``label``
        is the UTF-8 encoded label with spaces removed, passed through
        ``self.target_transform`` if set. An undecodable or missing image is
        replaced by sample 1.
        """
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        key_index = index + 1  # stored keys are 1-based
        # Read both raw records, then release the transaction before decoding. The
        # fallback below then never opens a second, nested read transaction.
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(('image-%09d' % key_index).encode('utf-8'))
            label = txn.get(('label-%09d' % key_index).encode('utf-8'))

        try:
            buf = six.BytesIO()
            buf.write(imgbuf)  # raises TypeError when the key is missing (imgbuf is None)
            buf.seek(0)
            img = Image.open(buf)
            # NOTE(review): random augmentation is applied only when target_transform is
            # None. This looks like a proxy for "training mode"; confirm intent.
            if self.target_transform is None:
                img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # PIL -> OpenCV
                img = random_transformers(img)  # random augmentation
                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # back to PIL
            img = Image.fromarray(np.uint8(np.array(img)))
        except Exception:
            # Corrupt or missing record: substitute another sample instead of crashing
            # the loader. If the substitute itself is broken, surface the error rather
            # than recursing forever.
            if index == 1:
                raise
            return self[1]

        if self.transform is not None:
            img = self.transform(img)

        # NOTE(review): as rendered here, both replace() arguments look identical. It was
        # presumably meant to map full-width '￥' to '¥'; verify the exact code points.
        label = label.decode('utf-8').replace(' ', '').replace('¥', '¥').encode('utf-8')
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

你可能感兴趣的:(计算机视觉,python,深度学习)