推荐一个无比强大的数据增强库,mask图像同时增强

朋友推荐的,非常好用。

地址:https://github.com/aleju/imgaug#list_of_augmenters

[图片:推荐一个无比强大的数据增强库,mask图像同时增强_第1张图片]

下面附上我自己使用的代码:

'''
@Author: haoMax
@Github: https://github.com/liuzehao
@Blog: https://blog.csdn.net/liu506039293
@Date: 2019-11-27 11:44:45
@LastEditTime: 2019-11-27 14:01:14
@LastEditors: haoMax
@Description: 
'''
import numpy as np
import imgaug.augmenters as iaa
import os
import cv2
# Side length, in pixels, that every image and mask is resized to.
PIXEL = 512
# Not referenced by the visible code; presumably the number of mask classes
# (TODO confirm).
class_num = 4
# Channel count of the input images after loading.
X_CHANNEL = 3
# Channel count of the masks after loading.
Y_CHANNEL = 1
def get_all_files(bg_path):
    """Recursively list the files under ``bg_path``, ordered by file number.

    The file number is the run of digits that ends the file name once the
    extension is removed (``0007.png`` -> 7, ``img_12.jpeg`` -> 12), so the
    ordering does not depend on zero padding or on the extension length.
    Files sharing the same number are ordered by their full path, which keeps
    the result deterministic regardless of directory listing order.

    Args:
        bg_path: Directory to scan; sub-directories are scanned as well.

    Returns:
        A list of file paths (each prefixed with ``bg_path``) sorted in
        ascending order of file number.

    Raises:
        ValueError: If a file name does not end in a number.
    """
    def _collect(directory, found):
        # Append the path of every file below `directory` to `found`.
        for entry in os.listdir(directory):
            full_path = os.path.join(directory, entry)
            if os.path.isfile(full_path):
                found.append(full_path)
            else:
                _collect(full_path, found)

    def _file_number(path):
        # Sort key: (trailing number of the file name, path as tie-breaker).
        stem = os.path.splitext(os.path.basename(path))[0]
        digits = stem[len(stem.rstrip('0123456789')):]
        if not digits:
            raise ValueError(
                'cannot order %r: file name does not end in a number' % path)
        return int(digits), path

    files = []
    _collect(bg_path, files)
    # Sort once, after the whole tree has been gathered (ascending order).
    files.sort(key=_file_number)
    return files
def generator(pathX, pathY, BATCH_SIZE, NUM):
    """Load a random batch of image/mask pairs resized to PIXEL x PIXEL.

    Images and masks are paired by position: the i-th file returned by
    ``get_all_files(pathX)`` is matched with the i-th file returned by
    ``get_all_files(pathY)``. Pairs are drawn independently, so the same pair
    may appear more than once in a batch.

    Args:
        pathX: Directory holding the input images.
        pathY: Directory holding the masks.
        BATCH_SIZE: Number of pairs to load.
        NUM: Size of the sampling pool; pairs are drawn from the first ``NUM``
            files. Values larger than the number of available pairs are
            clamped to that number.

    Returns:
        A tuple ``(X, Y)`` of numpy arrays. ``X`` stacks the images with shape
        ``(BATCH_SIZE, PIXEL, PIXEL, X_CHANNEL)`` and ``Y`` stacks the masks
        with shape ``(BATCH_SIZE, PIXEL, PIXEL, Y_CHANNEL)``.

    Raises:
        ValueError: If there is no image/mask pair to sample from.
        IOError: If OpenCV cannot read one of the selected files.
    """
    X_train_files = get_all_files(pathX)
    Y_train_files = get_all_files(pathY)

    # File lists are 0-indexed, so the pool is [0, NUM). It is clamped to the
    # shorter list so that a too-large NUM cannot index past the end.
    pool_size = min(NUM, len(X_train_files), len(Y_train_files))
    if pool_size < 1:
        raise ValueError(
            'no image/mask pairs to sample from (NUM=%r, %d images, %d masks)'
            % (NUM, len(X_train_files), len(Y_train_files)))

    X = []
    Y = []
    for _ in range(BATCH_SIZE):
        index = np.random.randint(pool_size)

        # Flag 1 loads a 3-channel colour image, flag 0 a single-channel one.
        img = cv2.imread(X_train_files[index], 1)
        img1 = cv2.imread(Y_train_files[index], 0)
        # cv2.imread reports an unreadable file by returning None.
        if img is None:
            raise IOError('cannot read image %r' % X_train_files[index])
        if img1 is None:
            raise IOError('cannot read mask %r' % Y_train_files[index])

        img = cv2.resize(img, (PIXEL, PIXEL))
        # Masks hold class labels: nearest-neighbour keeps every pixel equal to
        # an existing label instead of blending neighbouring labels.
        img1 = cv2.resize(img1, (PIXEL, PIXEL),
                          interpolation=cv2.INTER_NEAREST)

        X.append(np.array(img).reshape(PIXEL, PIXEL, X_CHANNEL))
        # The reshape adds the explicit channel axis to the 2-D mask.
        Y.append(np.array(img1).reshape(PIXEL, PIXEL, Y_CHANNEL))

    return np.array(X), np.array(Y)
# Input directories: source images and their annotation masks.
pathX = './images/training'
pathY = './annotations/training'
# Output directories for augmented images and masks. The two paths need
# distinct names; otherwise the second assignment overwrites the first.
# Neither is used by the visible script yet.
out_images = './augment/images'
out_annotations = './augment/annotations'
# Passed to generator() as NUM, which bounds the file indices it samples from.
NUM = 326
# Number of image/mask pairs loaded per generator() call.
BATCH_SIZE = 10
# Number of batches the script loads and augments.
epoch = 100

# The pipeline does not depend on the batch, so build it once instead of once
# per epoch: a random perspective warp followed by horizontal and vertical
# flips, each flip applied with probability 0.5.
seq = iaa.Sequential([
    iaa.PerspectiveTransform(scale=(0.01, 0.1)),
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
])

for _ in range(epoch):
    images, segmaps = generator(pathX, pathY, BATCH_SIZE, NUM)
    # Passing images and masks in one call keeps each mask geometrically
    # aligned with its image.
    images_aug, segmaps_aug = seq(images=images, segmentation_maps=segmaps)
    # Iterate the pairs directly so the epoch counter is not shadowed.
    for image_aug, segmap_aug in zip(images_aug, segmaps_aug):
        cv2.imshow('a', image_aug)
        # Scale the mask values for display; assumes labels 0..3 so that
        # 3 * 85 = 255 (see class_num) - TODO confirm.
        cv2.imshow('b', segmap_aug*85)
        # Block until a key is pressed before showing the next pair.
        cv2.waitKey(0)

 

你可能感兴趣的:(AR深度学习项目)