图像数据预处理 -- 数据增强、写入tfrecords

Augmentor 是一个很好用的图像数据增强 Python 库,支持多种图像变形变换。

  • 下面这段代码展示的是基于图像分割的数据集,同时生成增强的图像及其对应的label
import glob
import os
import shutil

import Augmentor


# Directories holding the source images and their ground-truth label images
AUGMENT_SOURCE_DIR = 'E:/datasets/leafs/imgs'
AUGMENT_LABEL_DIR = 'E:/datasets/leafs/lbls'

# Directory the augmented images are written to
# (original author's note: only absolute paths seem to be supported)
AUGMENT_OUTPUT_DIR = 'E:/datasets/leafs/img_aug'


def augment(n=8100):
    """Generate augmented image / ground-truth pairs with Augmentor.

    Builds a pipeline over ``AUGMENT_SOURCE_DIR``, registers the label images
    in ``AUGMENT_LABEL_DIR`` as ground truth, adds a chain of random
    transformations (each applied with probability 0.3) and writes the
    sampled results to ``AUGMENT_OUTPUT_DIR``.

    Args:
        n: Number of augmented samples to generate. Defaults to 8100, the
            value that was previously hard-coded.
    """
    p = Augmentor.Pipeline(
            source_directory=AUGMENT_SOURCE_DIR,
            output_directory=AUGMENT_OUTPUT_DIR
    )
    # Label directory: every label must carry the same file name as its image
    # (rename the files beforehand if necessary).
    p.ground_truth(ground_truth_directory=AUGMENT_LABEL_DIR)
    # Small rotation to the left or right
    p.rotate(probability=0.3, max_left_rotation=2, max_right_rotation=2)
    # Zoom in by a factor between 1.1 and 1.2
    p.zoom(probability=0.3, min_factor=1.1, max_factor=1.2)
    # Perspective skew
    p.skew(probability=0.3)
    # Elastic distortion; per the original author, grid_width / grid_height
    # must not exceed the size of the source image.
    p.random_distortion(probability=0.3, grid_width=20, grid_height=20, magnitude=1)
    # Shear to the left or right (this is a shear, not a crop)
    p.shear(probability=0.3, max_shear_left=2, max_shear_right=2)
    # Random crop keeping 80% of the image area
    p.crop_random(probability=0.3, percentage_area=0.8)
    # Random flip
    p.flip_random(probability=0.3)
    # Number of augmented images to generate
    p.sample(n=n)


# Split the augmented images and their labels into separate folders
def dispatch(root_dir='E:/datasets/leafs/img_aug',
             img_out='E:/datasets/leafs/images',
             lbl_out='E:/datasets/leafs/labels'):
    """Copy augmented image / label pairs into separate, numbered folders.

    Every file in ``root_dir`` whose name starts with ``_groundtruth`` is
    treated as a label. Its image is looked up by replacing the
    ``_groundtruth_(1)_imgs_`` prefix with ``imgs_original_``. Each complete
    pair is copied to ``img_out`` / ``lbl_out`` as ``<k>.png`` with ``k``
    counting up from 1, and the number of copied pairs is printed.

    Args:
        root_dir: Folder containing the mixed augmentation output.
        img_out: Destination folder for images (created if missing).
        lbl_out: Destination folder for labels (created if missing).
    """
    # copyfile does not create missing parent directories
    os.makedirs(img_out, exist_ok=True)
    os.makedirs(lbl_out, exist_ok=True)

    cnt = 0
    # Sort so the numbering is reproducible; os.listdir order is arbitrary
    for filename in sorted(os.listdir(root_dir)):
        if not filename.startswith('_groundtruth'):
            continue
        lbl_path = os.path.join(root_dir, filename)
        img_path = os.path.join(root_dir, filename.replace('_groundtruth_(1)_imgs_', 'imgs_original_'))
        # If the prefix did not match, img_path equals lbl_path and the label
        # would be copied as its own image; a missing image would abort the
        # whole run. Skip both cases and report them instead.
        if img_path == lbl_path or not os.path.isfile(img_path):
            print('Skipping %s: no matching image found' % filename)
            continue
        cnt += 1
        shutil.copyfile(img_path, os.path.join(img_out, '%d.png' % cnt))
        shutil.copyfile(lbl_path, os.path.join(lbl_out, '%d.png' % cnt))
    print(cnt)

 

  • 上述操作之后图像和标签会同时生成在同一文件夹(AUGMENT_OUTPUT_DIR)下面,其图像和对应的label命名是对应的,所以下面将二者分别转移到各自的文件夹下:
def standard_img_and_lbl(dir):
    """Re-save augmented image / label pairs under contiguous numeric names.

    Scans ``dir`` for ``*.png`` files whose file name contains
    ``image_original``, derives the label file name by replacing
    ``image_original_`` with ``_groundtruth_(1)_image_``, and writes each pair
    to ``AUGMENT_IMAGE_PATH`` / ``AUGMENT_LABEL_PATH`` as ``<k>.png``.

    ``k`` counts only the pairs actually written, starting at 0, so the output
    has no gaps; write_image_to_tfrecords reads the files by consecutive
    index.

    Args:
        dir: Folder containing the mixed augmentation output.
    """
    # NOTE(review): the 'image_' prefixes assume the augmentation source folder
    # was named 'image' -- confirm against the actual Augmentor output names.
    count = 0
    # Sort so the numbering is reproducible between runs
    for filename in sorted(glob.glob(dir + '/*.png')):
        # Match on the file name only, so a parent folder that happens to
        # contain 'image_original' cannot affect matching or replacement.
        folder, basename = os.path.split(filename)
        if 'image_original' not in basename:
            continue
        label_name = os.path.join(
            folder, basename.replace('image_original_', '_groundtruth_(1)_image_'))
        if label_name == filename:
            # The prefix did not match; the image would be used as its own label
            print('Skipping %s: cannot derive label name' % filename)
            continue
        img = cv2.imread(filename)
        lbl = cv2.imread(label_name)
        # cv2.imread returns None instead of raising when a file is unreadable
        if img is None or lbl is None:
            print('Skipping %s: image or label could not be read' % filename)
            continue
        cv2.imwrite(os.path.join(AUGMENT_IMAGE_PATH, '%d.png' % count), img)
        cv2.imwrite(os.path.join(AUGMENT_LABEL_PATH, '%d.png' % count), lbl)
        count += 1

 

  • 将图像写成TFRecords形式保存:TFRecords文件是一种二进制文件,不对数据进行压缩,因此可以被快速加载到内存中。该格式不支持随机访问,因此适合大量数据的流式读取,但不适用于快速分片或其他非连续存取。
def _load_resized_pair(image_dir, label_dir, idx):
    """Load ``<idx>.png`` from both folders, resize them and binarise the label."""
    image_file = os.path.join(image_dir, '%d.png' % idx)
    label_file = os.path.join(label_dir, '%d.png' % idx)
    image = cv2.imread(image_file)
    label = cv2.imread(label_file, 0)  # flag 0: read as single-channel grayscale
    # cv2.imread returns None for a missing/unreadable file; fail with the
    # offending path instead of an opaque error inside cv2.resize.
    if image is None:
        raise FileNotFoundError('Cannot read image: %s' % image_file)
    if label is None:
        raise FileNotFoundError('Cannot read label: %s' % label_file)
    image = cv2.resize(image, (INPUT_WIDTH, INPUT_HEIGHT))
    # Labels use the OUTPUT_* size that the reader reshapes to. Nearest-neighbour
    # keeps the mask crisp; interpolated edge values would otherwise all become
    # foreground in the binarisation below.
    label = cv2.resize(label, (OUTPUT_WIDTH, OUTPUT_HEIGHT),
                       interpolation=cv2.INTER_NEAREST)
    label[label != 0] = 1
    return image, label


def _write_tfrecord_split(record_path, image_dir, label_dir, indices, set_name, log_every):
    """Serialise the image/label pairs numbered by ``indices`` into one TFRecords file."""
    total = len(indices)
    # The context manager closes the writer even if a pair fails to load
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for done, idx in enumerate(indices):
            image, label = _load_resized_pair(image_dir, label_dir, idx)
            # Wrap the raw label and image bytes in an Example
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.tobytes()])),
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()]))
            }))
            writer.write(example.SerializeToString())
            if idx % log_every == 0:
                print('Done %s writing %.2f%%' % (set_name, done / total * 100))
    print('Done %s writing' % set_name)


def write_image_to_tfrecords():
    """Write the train, validation and predict sets as TFRecords files.

    Train and validation samples come from ``AUGMENT_IMAGE_PATH`` /
    ``AUGMENT_LABEL_PATH``: indices ``0 .. TRAIN_SET_SIZE - 1`` form the train
    set and the following ``VALIDATION_SET_SIZE`` indices the validation set.
    The predict set uses indices ``0 .. PREDICT_SET_SIZE - 1`` from
    ``ORIGIN_PREDICT_IMG_DIR`` / ``ORIGIN_PREDICT_LBL_DIR``.

    Each record holds two byte features: ``image_raw`` (image resized to
    INPUT_WIDTH x INPUT_HEIGHT) and ``label`` (grayscale mask resized to
    OUTPUT_WIDTH x OUTPUT_HEIGHT and binarised to 0/1).

    Raises:
        FileNotFoundError: If an expected image or label file cannot be read.
    """
    out_dir = './dataset/my_set'
    os.makedirs(out_dir, exist_ok=True)

    # train set
    _write_tfrecord_split(
        os.path.join(out_dir, TRAIN_SET_NAME),
        AUGMENT_IMAGE_PATH, AUGMENT_LABEL_PATH,
        range(TRAIN_SET_SIZE),
        'train_set', 100)

    # validation set: the indices directly after the train range
    _write_tfrecord_split(
        os.path.join(out_dir, VALIDATION_SET_NAME),
        AUGMENT_IMAGE_PATH, AUGMENT_LABEL_PATH,
        range(TRAIN_SET_SIZE, TRAIN_SET_SIZE + VALIDATION_SET_SIZE),
        'validation_set', 10)

    # predict set: read from the original (non-augmented) folders
    _write_tfrecord_split(
        os.path.join(out_dir, PREDICT_SET_NAME),
        ORIGIN_PREDICT_IMG_DIR, ORIGIN_PREDICT_LBL_DIR,
        range(PREDICT_SET_SIZE),
        'predict_set', 10)

 

  • 读取并验证TFRecords文件是否存储正确:
# Image (input) and label (output) dimensions used when decoding records
INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNEL = 512, 512, 3
OUTPUT_WIDTH, OUTPUT_HEIGHT, OUTPUT_CHANNEL = 512, 512, 1
# File name of the training TFRecords file and the folder it is read from
TRAIN_SET_NAME = 'train_set.tfrecords'
TFRECORDS_DIR = './dataset/my_set'


# Read one image and its corresponding label
def read_image(file_queue):
    """Parse one serialized example from a TFRecords filename queue.

    Args:
        file_queue: Filename queue, e.g. from tf.train.string_input_producer.

    Returns:
        A tuple ``(image, label)`` of uint8 tensors shaped
        ``[INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNEL]`` and
        ``[OUTPUT_HEIGHT, OUTPUT_WIDTH]``.
    """
    # Reader for TFRecords files
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    # Parse the two byte-string features stored in each record
    features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.string),
                'image_raw': tf.FixedLenFeature([], tf.string)
            }
    )
    # Decode the raw bytes back into uint8 pixel data.
    # The arrays were produced by cv2.resize(..., (width, height)), which yields
    # rows = height and cols = width, so the reshape must be height-first.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNEL])
    label = tf.decode_raw(features['label'], tf.uint8)
    label = tf.reshape(label, [OUTPUT_HEIGHT, OUTPUT_WIDTH])
    return image, label


# Display one image and its label read back from the TFRecords file
def read_check_tfrecords():
    """Read one record from the training TFRecords file and display it.

    Opens ``TFRECORDS_DIR/TRAIN_SET_NAME``, decodes a single image/label pair
    with read_image, shows both in OpenCV windows and waits for a key press.
    """
    train_file_path = os.path.join(TFRECORDS_DIR, TRAIN_SET_NAME)
    train_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(train_file_path),
            num_epochs=1,
            shuffle=True
    )
    train_images, train_labels = read_image(train_image_filename_queue)
    with tf.Session() as sess:
        # match_filenames_once and num_epochs rely on local variables, so the
        # local initializer is needed as well as the global one.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            example, label = sess.run([train_images, train_labels])
            cv2.imshow('image', example)
            # write_image_to_tfrecords stores labels as 0/1, which would show
            # as an all-black window; scale foreground to 255 for display.
            cv2.imshow('label', (label != 0).astype('uint8') * 255)
            cv2.waitKey(0)
        finally:
            # Always stop and join the queue-runner threads, even on failure
            coord.request_stop()
            coord.join(threads)
    print('Done reading and checking.')
    
# read_check_tfrecords()

 

你可能感兴趣的:(图像数据预处理 -- 数据增强、写入tfrecords)