tensorflow2.0 做数据增强

使用tf2.0做数据增强:
对图片进行反转、旋转等操作。

代码:
需要库

from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
import os

此函数执行数据增强功能,前两个参数是需要被做数据增强的两个文件夹地址,后两个参数是做完数据增强后保存图片的地址。

def DataInc(path_cat, path_tiger,to_path_cat, to_path_tiger): 
    #定义图片生成器
    data_gen = ImageDataGenerator(rotation_range=40,
                                  width_shift_range=0.2,
                                  height_shift_range=0.2,
                                  horizontal_flip=True,
                                  vertical_flip=True,
                                  fill_mode='nearest',
                                  data_format='channels_last')
    imgs_cat = os.listdir(path_cat)
    imgs_tiger = os.listdir(path_tiger)
    for img_cat in imgs_cat:
        try:
            img=load_img(path_cat+'\\'+img_cat)
            x = img_to_array(img,data_format="channels_last")   #图片转化成array类型,因flow()接收numpy数组为参数
            x=x.reshape((1,) + x.shape)     #要求为4维

            #使用for循环迭代,生成图片
            i = 0
            for batch in data_gen.flow(x,batch_size=1,
                                       save_to_dir=to_path_cat,
                                       save_prefix='cat',
                                       save_format='jpeg'):
                i += 1
                if i>15:
                    break
        except:
            print('555...Error!')
    for img_tiger in imgs_tiger:
        try:
            img=load_img(path_tiger+'\\'+img_tiger)
            x = img_to_array(img,data_format="channels_last")   #图片转化成array类型,因flow()接收numpy数组为参数
            x=x.reshape((1,) + x.shape)     #要求为4维

            #使用for循环迭代,生成图片
            i = 0
            for batch in data_gen.flow(x,batch_size=1,
                                       save_to_dir=to_path_tiger,
                                       save_prefix='tiger',
                                       save_format='jpeg'):
                i += 1
                if i>20:
                    break
        except:
            print('555...Error!')

执行如下:

path_cat = 'E:\\Machine Learning\\data\\CatAndTiger\\cat'
path_tiger = 'E:\\Machine Learning\\data\\CatAndTiger\\tiger'
to_path_cat = 'E:\\Machine Learning\\data\\CatAndTiger\\cat_inc'
to_path_tiger = 'E:\\Machine Learning\\data\\CatAndTiger\\tiger_inc'
DataInc(path_cat, path_tiger,to_path_cat, to_path_tiger)

你可能感兴趣的:(神经网络)