LMDB与SQLite/MySQL等关系型数据库不同,它属于key-value数据库(可以把LMDB想象成Python中的dict来理解),键(key)与值(value)都是字节串(bytes),在Python中写入前需先将str编码为bytes。
pip install lmdb
### 操作流程
1.创建lmdb环境
env = lmdb.open()
2. 建立事务(默认为只读事务;写入、修改或删除数据时需使用 env.begin(write=True))
txn = env.begin()
3.向事务中写入或者修改数据
txn.put(key, value)
4. 删除数据
txn.delete(key)
5. 数据查询
txn.get(key)
6. 数据遍历
txn.cursor()
7. 提交事务(写事务中的修改在提交后才会生效)
txn.commit()
在OCR文本识别任务中,训练数据量较大,因此将训练数据转换为LMDB格式保存。转换时需要指定图片的主路径与标签文件,
示例代码如下所示:
# coding:utf-8
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import re
from PIL import Image
import numpy as np
import imghdr
import argparse
from tqdm import tqdm
def init_args():
    """Parse the command-line options for building the LMDB dataset.

    Returns:
        argparse.Namespace with the attributes ``image_dir``, ``label_file``,
        ``save_dir`` and ``map_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--image_dir',
        type=str,
        default='',
        help='The directory of the dataset , which contains the images',
    )
    parser.add_argument(
        '-l', '--label_file',
        type=str,
        default='/datassd/hzl/text_render_data/mingpian/new_gray.txt',
        help='The file which contains the paths and the labels of the data set',
    )
    parser.add_argument(
        '-s', '--save_dir',
        type=str,
        default='/datassd/hzl/text_render_data/mingpian/lmdb_gray/',
        help='The generated mdb file save dir',
    )
    parser.add_argument(
        '-m', '--map_size',
        type=int,
        default=40000000000000,
        help='map size of lmdb',
    )
    parsed = parser.parse_args()
    return parsed
def checkImageIsValid(imageBin):
    """Return True if *imageBin* decodes (as grayscale) to a non-empty image.

    Args:
        imageBin: encoded image file contents (e.g. the bytes of a .jpg/.png),
            or None.

    Returns:
        False for None/empty input, for data OpenCV cannot decode, and for an
        image with zero height or width; True otherwise.
    """
    if imageBin is None or len(imageBin) == 0:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    try:
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    except cv2.error:
        return False
    # imdecode signals undecodable data by returning None rather than raising.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
def writeCache(env, cache):
    """Write every key/value pair of *cache* to *env* in one write transaction.

    LMDB stores raw bytes, so ``str`` keys and values (including ``str``
    subclasses) are UTF-8 encoded first. ``bytes`` values, such as encoded
    image data, are written unchanged.

    Args:
        env: an opened LMDB environment.
        cache: mapping of key -> value, each ``str`` or ``bytes``.
    """
    with env.begin(write=True) as txn:  # open a write transaction
        for key, value in cache.items():
            if isinstance(key, str):
                key = key.encode('utf-8')
            if isinstance(value, str):
                value = value.encode('utf-8')
            txn.put(key, value)  # write the pair
def createDataset(outputPath, imagePathList, labelList, map_size, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Keys use a 1-based counter. For each sample the following are written:
    ``image-%09d`` (raw image file bytes), ``label-%09d`` (label text) and,
    when *lexiconList* is given, ``lexicon-%09d`` (space-joined lexicon).
    ``num-samples`` holds the number of samples actually written, i.e.
    excluding images rejected by the validity check.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        map_size      : maximum map size, passed to ``lmdb.open``
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert len(imagePathList) == len(labelList), \
        'imagePathList and labelList must have the same length'
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=map_size)  # create the LMDB environment
    try:
        cache = {}
        # 1-based keys: the LMDB reader (LMDB.__getitem__) reads key
        # ``index + 1`` for dataset index ``index``.
        cnt = 1
        for i in tqdm(range(nSamples)):
            # Drop stray line terminators, including the '\r' of CRLF files.
            imagePath = imagePathList[i].replace('\r', '').replace('\n', '')
            label = labelList[i]
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label
            if lexiconList:
                cache['lexicon-%09d' % cnt] = ' '.join(lexiconList[i])
            if cnt % 1000 == 0:
                # Flush periodically so the in-memory cache stays small.
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nWritten = cnt - 1  # invalid images were skipped, so this may be < nSamples
        cache['num-samples'] = str(nWritten)
        writeCache(env, cache)
    finally:
        env.close()  # close the LMDB environment even if writing failed
    print('Created dataset with %d samples' % nWritten)
def load_label_file(label_file, image_dir=''):
    """Read an OCR label file into parallel lists of image paths and labels.

    Each non-blank line holds an image path followed by its label. A line
    containing a tab is split at its first tab, so labels may contain spaces.
    Any other line is split at its first run of whitespace. Each image path is
    joined onto *image_dir*; with the default ``''`` or an absolute path, the
    path is used unchanged. Lines without a label, and lines whose image file
    does not exist, are skipped.

    Args:
        label_file: path of the UTF-8 label file.
        image_dir: optional directory that relative image paths are resolved
            against.

    Returns:
        (imgPathList, labelList), with one entry per usable line.
    """
    imgPathList = []
    labelList = []
    nMissing = 0
    with open(label_file, mode='r', encoding='utf-8') as imgdata:
        for lineNo, line in enumerate(imgdata, 1):
            line = line.strip()
            if not line:
                continue
            if '\t' in line:
                fields = line.split('\t', 1)
            else:
                fields = line.split(None, 1)
            if len(fields) != 2:
                print('line %d of %s has no label, skipped' % (lineNo, label_file))
                continue
            imgPath = os.path.join(image_dir, fields[0].strip())
            word = fields[1]
            if not os.path.exists(imgPath):
                nMissing += 1
                continue
            imgPathList.append(imgPath)
            labelList.append(word)
    if nMissing:
        print('%d image(s) listed in %s do not exist, skipped' % (nMissing, label_file))
    return imgPathList, labelList


if __name__ == '__main__':
    args = init_args()  # parse command-line options
    # Read the label file, dropping malformed lines and missing images.
    imgPathList, labelList = load_label_file(args.label_file, args.image_dir)
    # Write the LMDB dataset.
    createDataset(args.save_dir, imgPathList, labelList, args.map_size)
import lmdb
import six
import sys
from PIL import Image
import cv2
import numpy as np
from lib.dataset.transformers import *
class LMDB(Dataset):
    """Text-recognition dataset read from an LMDB database.

    Samples are stored under 1-based keys: dataset index ``i`` reads
    ``image-%09d`` / ``label-%09d`` at ``i + 1``. The sample count comes from
    the ``num-samples`` key. ``config``, ``Dataset`` and
    ``random_transformers`` are module-level names, presumably provided by
    ``from lib.dataset.transformers import *`` -- NOTE(review): confirm.
    """

    def __init__(self, root=None, is_train=True):
        # Train loads the training LMDB; evaluation loads the test LMDB.
        if is_train:
            root = config.DATASET.TRAIN_FILE
        else:
            root = config.DATASET.TEST_FILE
        # NOTE(review): this unconditionally overrides both the train/test
        # choice above and the `root` argument, so config.LMDB_ROOT is always
        # the database opened. Kept as-is; confirm whether that is intended.
        root = config.LMDB_ROOT
        # Previously misspelled `trainsform`, which left `transform` undefined.
        transform = config.DATASET.TRANSFORM
        target_transform = config.DATASET.TARGET_TRANSFORM
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            map_size=40000000000000)
        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(1)  # failure must not report a success exit status
        with self.env.begin(write=False) as txn:
            self.nSamples = int(txn.get('num-samples'.encode('utf-8')))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label)`` for dataset position *index*.

        Negative indices count from the end. If the stored image cannot be
        decoded or augmented, sample 1 is returned instead. This best-effort
        fallback is kept from the original implementation.

        Raises:
            IndexError: if *index* is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        key_index = index + 1  # stored keys are 1-based
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(('image-%09d' % key_index).encode('utf-8'))
            label = txn.get(('label-%09d' % key_index).encode('utf-8'))
        try:
            img = self._load_image(imgbuf)
        except Exception:
            if index == 1:
                # Sample 1 is the fallback itself; recursing would never end.
                raise
            return self[1]
        if self.transform is not None:
            img = self.transform(img)
        # Remove spaces and swap one yen-sign literal for the other.
        # NOTE(review): the two literals may be different code points (e.g.
        # U+FFE5 vs U+00A5); if they are identical this replace is a no-op.
        label = label.decode('utf-8').replace(' ', '').replace('¥', '¥').encode('utf-8')
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def _load_image(self, imgbuf):
        """Decode stored image bytes into a PIL image, augmenting when enabled."""
        if imgbuf is None:
            raise KeyError('image key missing from LMDB')
        img = Image.open(six.BytesIO(imgbuf))
        # NOTE(review): random augmentation runs only when target_transform is
        # None; presumably that marks the training configuration -- confirm.
        if self.target_transform is None:
            img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # to OpenCV for augmentation
            img = random_transformers(img)  # random augmentation
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # back to PIL
        return Image.fromarray(np.uint8(np.array(img)))