Augmentor is a handy Python library for image data augmentation; it supports a wide range of geometric deformations and transforms and is installable via pip (the package name is Augmentor).
# Imports used by the snippets in this post (TensorFlow 1.x APIs are used throughout)
import glob
import os
import shutil

import cv2
import tensorflow as tf

import Augmentor
# Directories of the source images and their labels
AUGMENT_SOURCE_DIR = 'E:/datasets/leafs/imgs'
AUGMENT_LABEL_DIR = 'E:/datasets/leafs/lbls'
# Output directory for the augmented images (apparently only absolute paths are supported)
AUGMENT_OUTPUT_DIR = 'E:/datasets/leafs/img_aug'
def augment():
    p = Augmentor.Pipeline(
        source_directory=AUGMENT_SOURCE_DIR,
        output_directory=AUGMENT_OUTPUT_DIR
    )
    # Directory of the ground-truth labels; each label must have the same
    # filename as its image (rename/preprocess them yourself if necessary)
    p.ground_truth(ground_truth_directory=AUGMENT_LABEL_DIR)
    # Rotation, applied with probability 0.3
    p.rotate(probability=0.3, max_left_rotation=2, max_right_rotation=2)
    # Zoom
    p.zoom(probability=0.3, min_factor=1.1, max_factor=1.2)
    # Skew
    p.skew(probability=0.3)
    # Elastic distortion; grid_width / grid_height must not exceed the image size
    p.random_distortion(probability=0.3, grid_width=20, grid_height=20, magnitude=1)
    # Shear
    p.shear(probability=0.3, max_shear_left=2, max_shear_right=2)
    # Random crop
    p.crop_random(probability=0.3, percentage_area=0.8)
    # Random flip
    p.flip_random(probability=0.3)
    # Number of augmented samples to generate
    p.sample(n=8100)
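After p.sample() finishes, Augmentor writes the augmented images and their ground-truth masks side by side into AUGMENT_OUTPUT_DIR, distinguished only by filename prefix. A quick, hypothetical sanity check of the generated names (the exact prefixes depend on the source folder name; the ones in the comment are what dispatch() below expects):

# List a few generated files; images look like 'imgs_original_...' and the
# matching masks like '_groundtruth_(1)_imgs_...'.
for name in sorted(os.listdir(AUGMENT_OUTPUT_DIR))[:5]:
    print(name)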
# Separate images from labels: Augmentor dumps both into the same output
# directory, so copy them into dedicated image/label folders, renamed by index.
def dispatch():
    root_dir = 'E:/datasets/leafs/img_aug'
    img_out = 'E:/datasets/leafs/images'
    lbl_out = 'E:/datasets/leafs/labels'
    cnt = 0
    files = os.listdir(root_dir)
    for filename in files:
        if filename.startswith('_groundtruth'):
            lbl_path = os.path.join(root_dir, filename)
            # Derive the matching image filename from the mask filename
            img_path = os.path.join(root_dir, filename.replace('_groundtruth_(1)_imgs_', 'imgs_original_'))
            cnt += 1
            shutil.copyfile(img_path, os.path.join(img_out, '%d.png' % cnt))
            shutil.copyfile(lbl_path, os.path.join(lbl_out, '%d.png' % cnt))
    print(cnt)
# Alternative split/renaming step: copy each image/label pair into the
# AUGMENT_IMAGE_PATH / AUGMENT_LABEL_PATH folders with sequential names.
# Note that the 'image_original_' prefix must match the actual Augmentor output.
def standard_img_and_lbl(dir):
    filenames = glob.glob(dir + '/*.png')
    for idx, filename in enumerate(filenames):
        if 'image_original' in filename:
            label_name = filename.replace('image_original_', '_groundtruth_(1)_image_')
            img = cv2.imread(filename)
            lbl = cv2.imread(label_name)
            cv2.imwrite(os.path.join(AUGMENT_IMAGE_PATH, '%d.png' % idx), img)
            cv2.imwrite(os.path.join(AUGMENT_LABEL_PATH, '%d.png' % idx), lbl)
def write_image_to_tfrecords():
    # Folders holding the augmented images and labels
    augment_image_path = AUGMENT_IMAGE_PATH
    augment_label_path = AUGMENT_LABEL_PATH
    # TFRecord files to generate: train, validation and predict sets
    train_set_writer = tf.python_io.TFRecordWriter(os.path.join('./dataset/my_set', TRAIN_SET_NAME))
    validation_set_writer = tf.python_io.TFRecordWriter(os.path.join('./dataset/my_set', VALIDATION_SET_NAME))
    predict_set_writer = tf.python_io.TFRecordWriter(os.path.join('./dataset/my_set', PREDICT_SET_NAME))
    # train set
    for idx in range(TRAIN_SET_SIZE):
        train_image = cv2.imread(os.path.join(augment_image_path, '%d.png' % idx))
        train_label = cv2.imread(os.path.join(augment_label_path, '%d.png' % idx), 0)
        train_image = cv2.resize(train_image, (INPUT_WIDTH, INPUT_HEIGHT))
        train_label = cv2.resize(train_label, (INPUT_WIDTH, INPUT_HEIGHT))
        # Binarize the mask (resizing may introduce intermediate values)
        train_label[train_label != 0] = 1
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_label.tobytes()])),
            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_image.tobytes()]))
        }))  # wrap label and image into an Example proto
        train_set_writer.write(example.SerializeToString())
        if idx % 100 == 0:
            print('Done train_set writing %.2f%%' % (idx / TRAIN_SET_SIZE * 100))
    train_set_writer.close()
    print('Done train_set writing.')
    # validation set
    for idx in range(TRAIN_SET_SIZE, TRAIN_SET_SIZE + VALIDATION_SET_SIZE):
        validation_image = cv2.imread(os.path.join(augment_image_path, '%d.png' % idx))
        validation_label = cv2.imread(os.path.join(augment_label_path, '%d.png' % idx), 0)
        validation_image = cv2.resize(validation_image, (INPUT_WIDTH, INPUT_HEIGHT))
        validation_label = cv2.resize(validation_label, (INPUT_WIDTH, INPUT_HEIGHT))
        validation_label[validation_label != 0] = 1
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[validation_label.tobytes()])),
            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[validation_image.tobytes()]))
        }))
        validation_set_writer.write(example.SerializeToString())  # serialize to a byte string
        if idx % 10 == 0:
            print('Done validation_set writing %.2f%%' % ((idx - TRAIN_SET_SIZE) / VALIDATION_SET_SIZE * 100))
    validation_set_writer.close()
    print('Done validation_set writing')
    # predict set (taken from the original, non-augmented predict images)
    predict_image_path = ORIGIN_PREDICT_IMG_DIR
    predict_label_path = ORIGIN_PREDICT_LBL_DIR
    for idx in range(PREDICT_SET_SIZE):
        predict_image = cv2.imread(os.path.join(predict_image_path, '%d.png' % idx))
        predict_label = cv2.imread(os.path.join(predict_label_path, '%d.png' % idx), 0)
        predict_image = cv2.resize(predict_image, (INPUT_WIDTH, INPUT_HEIGHT))
        predict_label = cv2.resize(predict_label, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
        predict_label[predict_label != 0] = 1
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[predict_label.tobytes()])),
            'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[predict_image.tobytes()]))
        }))
        predict_set_writer.write(example.SerializeToString())
        if idx % 10 == 0:
            print('Done predict_set writing %.2f%%' % (idx / PREDICT_SET_SIZE * 100))
    predict_set_writer.close()
    print('Done predict_set writing')
INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNEL = 512, 512, 3
OUTPUT_WIDTH, OUTPUT_HEIGHT, OUTPUT_CHANNEL = 512, 512, 1
TRAIN_SET_NAME = 'train_set.tfrecords'
TFRECORDS_DIR = './dataset/my_set'
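The code above also references several constants that are not defined in this excerpt (the validation/predict file names, the set sizes, and the various directories). A placeholder sketch with values that are purely assumptions and should be adjusted to your own dataset:

# Assumed placeholder values -- not from the original post; adjust as needed.
VALIDATION_SET_NAME = 'validation_set.tfrecords'
PREDICT_SET_NAME = 'predict_set.tfrecords'
TRAIN_SET_SIZE = 7000         # assumption: train/validation together cover the 8100 augmented samples
VALIDATION_SET_SIZE = 1100    # assumption
PREDICT_SET_SIZE = 30         # assumption
AUGMENT_IMAGE_PATH = 'E:/datasets/leafs/images'   # assumption: same folder dispatch() writes images to
AUGMENT_LABEL_PATH = 'E:/datasets/leafs/labels'   # assumption: same folder dispatch() writes labels to
ORIGIN_PREDICT_IMG_DIR = 'E:/datasets/leafs/predict_imgs'  # assumption
ORIGIN_PREDICT_LBL_DIR = 'E:/datasets/leafs/predict_lbls'  # assumption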
# Read one image and its corresponding label from the TFRecord file queue
def read_image(file_queue):
    # Reader for TFRecord files (TF 1.x queue-based input pipeline)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    # Parse the serialized example
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    # Decode the raw bytes back into uint8 images
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNEL])
    label = tf.decode_raw(features['label'], tf.uint8)
    label = tf.reshape(label, [OUTPUT_WIDTH, OUTPUT_HEIGHT])
    return image, label
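read_image() returns a single (image, label) pair; for training these would normally be grouped into shuffled mini-batches. A minimal sketch using the TF 1.x queue-based tf.train.shuffle_batch (the batch size and queue capacities below are assumptions):

def read_image_batch(file_queue, batch_size=8):
    # Read one example, then let a shuffling queue assemble mini-batches.
    image, label = read_image(file_queue)
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,   # assumed batch size
        capacity=200,            # assumed queue capacity
        min_after_dequeue=100    # assumed minimum fill level before dequeuing
    )
    return image_batch, label_batch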
# Read one example back from the train TFRecord and display it, to verify the data
def read_check_tfrecords():
    train_file_path = os.path.join(TFRECORDS_DIR, TRAIN_SET_NAME)
    train_image_filename_queue = tf.train.string_input_producer(
        string_tensor=tf.train.match_filenames_once(train_file_path),
        num_epochs=1,
        shuffle=True
    )
    train_images, train_labels = read_image(train_image_filename_queue)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        example, label = sess.run([train_images, train_labels])
        cv2.imshow('image', example)
        # The mask only contains 0/1, so scale it to 0/255 to make it visible
        cv2.imshow('label', label * 255)
        cv2.waitKey(0)
        coord.request_stop()
        coord.join(threads)
    print('Done reading and checking.')
# read_check_tfrecords()
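For completeness, a minimal sketch of how the steps above could be chained together (the call order is an assumption based on the flow of this post):

if __name__ == '__main__':
    augment()                    # 1. generate augmented image/mask pairs with Augmentor
    dispatch()                   # 2. split Augmentor's output into image/label folders (or use standard_img_and_lbl)
    write_image_to_tfrecords()   # 3. pack everything into TFRecord files
    read_check_tfrecords()       # 4. read one example back to verify the records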