TensorFlow 2.x Study Notes 7: Using ImageDataGenerator

tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False, zca_epsilon=1e-06, rotation_range=0,
    width_shift_range=0.0, height_shift_range=0.0,
    brightness_range=None, shear_range=0.0, zoom_range=0.0,
    channel_shift_range=0.0, fill_mode='nearest', cval=0.0,
    horizontal_flip=False, vertical_flip=False, rescale=None,
    preprocessing_function=None,
    data_format=None, validation_split=0.0, dtype=None
)

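The constructor takes too many arguments to go through them all here (I'm being lazy), so see the official API documentation for the details.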
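To give a feel for what a few of these arguments do, the sketch below approximates in plain NumPy the per-image normalization applied by `rescale`, `samplewise_center` and `samplewise_std_normalization`: rescale first, then subtract the image's own mean, then divide by its own standard deviation plus a small epsilon. It is a simplified re-implementation written for this note, not the Keras source, and the helper name `standardize_sample` is hypothetical.

```python
import numpy as np

def standardize_sample(x, rescale=None, samplewise_center=False,
                       samplewise_std_normalization=False):
    """Hypothetical NumPy approximation of per-image normalization."""
    x = x.astype(np.float32)  # work on a float copy of the image
    if rescale:
        x = x * rescale  # e.g. 1./255 maps pixel values 0..255 to 0..1
    if samplewise_center:
        x = x - np.mean(x, keepdims=True)  # subtract this image's mean
    if samplewise_std_normalization:
        # divide by this image's std; the epsilon avoids division by zero
        x = x / (np.std(x, keepdims=True) + 1e-6)
    return x

img = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.uint8)
scaled = standardize_sample(img, rescale=1./255)
normed = standardize_sample(img, samplewise_center=True,
                            samplewise_std_normalization=True)
```

After `samplewise_center` plus `samplewise_std_normalization`, each image has roughly zero mean and unit standard deviation, independently of the other images in the dataset.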

The examples below show how to use ImageDataGenerator.

1. Using the flow method

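import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator

num_classes = 10  # CIFAR-10 has 10 classes
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# One-hot encode the labels (tf.keras.utils replaces standalone Keras np_utils)
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
# The featurewise_* options need dataset statistics, so fit() must come first
datagen.fit(x_train)
# model and epochs are assumed to be defined elsewhere;
# model.fit accepts generators in TF 2.x (fit_generator is deprecated)
model.fit(datagen.flow(x_train, y_train, batch_size=32),
          steps_per_epoch=len(x_train) // 32, epochs=epochs)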
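The `featurewise_center` and `featurewise_std_normalization` options rely on statistics computed over the whole training set, which is why `datagen.fit(x_train)` has to run before `flow`. The NumPy sketch below approximates that computation for channels-last data: one mean and one standard deviation per channel, taken over all images, rows and columns, then subtracted and divided out. The helper names are hypothetical and the random array only stands in for `x_train`.

```python
import numpy as np

def fit_featurewise(x):
    """Per-channel mean/std over (samples, rows, cols) for channels-last data."""
    x = x.astype(np.float32)
    mean = np.mean(x, axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, C)
    std = np.std(x, axis=(0, 1, 2), keepdims=True)
    return mean, std

def apply_featurewise(x, mean, std):
    # Subtract the dataset mean, divide by the dataset std (+ epsilon)
    return (x.astype(np.float32) - mean) / (std + 1e-6)

rng = np.random.default_rng(0)
# Stand-in for x_train: 100 random 32x32 RGB images
x_fake = rng.integers(0, 256, size=(100, 32, 32, 3)).astype(np.uint8)
mean, std = fit_featurewise(x_fake)
x_norm = apply_featurewise(x_fake, mean, std)
```

As for `steps_per_epoch`: with CIFAR-10's 50,000 training images and a batch size of 32, `len(x_train) // 32` gives 1562 whole batches per epoch, whereas the original `/` would produce the float 1562.5.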

2. Using the flow_from_directory method

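from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training data: rescale pixels to [0, 1] and apply random augmentation
train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
# Validation data: rescale only, no augmentation
test_datagen = ImageDataGenerator(rescale=1./255)
# Labels are inferred from the subdirectory names under each folder
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
# model is assumed to be a compiled tf.keras model defined elsewhere.
# steps_per_epoch / validation_steps count batches, not images;
# len(generator) is the number of batches in one pass over the data.
model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=50,
        validation_data=validation_generator,
        validation_steps=len(validation_generator))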

All we need to do is store the training and validation data in the following folder layout:
[Figure 1: folder layout for flow_from_directory, with one subfolder per class under data/train/ and data/validation/]
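flow_from_directory treats each subdirectory of the folder it is given as one class and, when `classes` is not passed, assigns label indices in alphanumeric order of the subdirectory names. The standard-library sketch below builds a tiny throwaway layout with empty placeholder files and reproduces that mapping. The class names `cats` and `dogs` and the helper `infer_classes` are hypothetical, and the sketch is simplified: it only looks one level deep and uses a plain extension filter.

```python
import os
import tempfile

IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

def infer_classes(directory):
    """Map sorted subdirectory names to indices; collect image files per class."""
    class_names = sorted(
        d for d in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, d)))
    class_indices = {name: i for i, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        for fname in sorted(os.listdir(os.path.join(directory, name))):
            if fname.lower().endswith(IMAGE_EXTS):  # skip non-image files
                samples.append((os.path.join(name, fname), class_indices[name]))
    return class_indices, samples

# Build data/train/<class>/<file> with empty placeholder files
root = tempfile.mkdtemp()
train_dir = os.path.join(root, 'data', 'train')
for cls, files in {'dogs': ['dog001.jpg', 'dog002.jpg'],
                   'cats': ['cat001.jpg', 'notes.txt']}.items():
    os.makedirs(os.path.join(train_dir, cls))
    for f in files:
        open(os.path.join(train_dir, cls, f), 'w').close()

class_indices, samples = infer_classes(train_dir)
print(class_indices)  # {'cats': 0, 'dogs': 1}
print(len(samples))   # 3, because notes.txt is skipped
```

Because the order is alphanumeric, `cats` gets index 0 even though `dogs` was created first; with `class_mode='binary'` those indices become the 0/1 labels.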
3. Using the flow_from_dataframe method

## Convert the data into the form below: a list of tuples, each holding an image path and its class, then turn the list into a DataFrame.
[('./cifar10/train/1.png', 'frog'),
 ('./cifar10/train/2.png', 'truck'),
 ('./cifar10/train/3.png', 'truck'),
 ('./cifar10/train/4.png', 'deer'),
 ('./cifar10/train/5.png', 'automobile')]
[('./cifar10/test/1.png', 'cat'),
 ('./cifar10/test/2.png', 'cat'),
 ('./cifar10/test/3.png', 'cat'),
 ('./cifar10/test/4.png', 'cat'),

[Figure 2: the resulting DataFrame, with columns filepath and class]
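Turning the list of `(path, class)` tuples into a DataFrame takes a single pandas call. The sketch below uses the five training rows listed above; the column names `filepath` and `class` are chosen to match the `x_col` and `y_col` arguments passed to flow_from_dataframe. The `class_names` list shown is one possible definition (the ten CIFAR-10 categories in their usual label order), since the note does not show how it was built.

```python
import pandas as pd

# Rows copied from the training list above: (image path, class name)
train_list = [('./cifar10/train/1.png', 'frog'),
              ('./cifar10/train/2.png', 'truck'),
              ('./cifar10/train/3.png', 'truck'),
              ('./cifar10/train/4.png', 'deer'),
              ('./cifar10/train/5.png', 'automobile')]

# Column names must match x_col / y_col given to flow_from_dataframe
train_df = pd.DataFrame(train_list, columns=['filepath', 'class'])

# Assumed definition: the ten CIFAR-10 categories in their usual label order
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

print(train_df.head())
```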

from tensorflow import keras

# train_df, class_names, height, width and batch_size are assumed to be defined elsewhere
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')
train_generator = train_datagen.flow_from_dataframe(
    train_df,
    ## directory: string, path to the directory to read images from.
    ## If `None`, data in the `x_col` column should be absolute paths.
    directory='./',
    x_col='filepath',    # column holding the image paths
    y_col='class',       # column holding the class names
    classes=class_names,
    target_size=(height, width),
    batch_size=batch_size,
    seed=7,
    shuffle=True,
    class_mode='sparse')  # labels as integer indices into class_names
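With `classes=class_names` and `class_mode='sparse'`, each class-name string in the `class` column becomes its integer position in `class_names`, instead of a one-hot vector as with `class_mode='categorical'`. The pure-Python sketch below mirrors that mapping for illustration only; it does not call Keras, the `to_sparse` and `to_one_hot` helpers are hypothetical, and `class_names` is again assumed to be the CIFAR-10 list.

```python
# Assumed class list: the ten CIFAR-10 categories in their usual label order
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
# Each class's index is its position in class_names
class_indices = {name: i for i, name in enumerate(class_names)}

def to_sparse(labels):
    """'sparse'-style labels: one integer per sample."""
    return [class_indices[name] for name in labels]

def to_one_hot(labels):
    """'categorical'-style labels: one one-hot row per sample."""
    rows = []
    for idx in to_sparse(labels):
        row = [0] * len(class_names)
        row[idx] = 1
        rows.append(row)
    return rows

batch_labels = ['frog', 'truck', 'truck', 'deer', 'automobile']
print(to_sparse(batch_labels))  # [6, 9, 9, 4, 1]
```

Integer labels like these pair with a `sparse_categorical_crossentropy` loss, while one-hot labels pair with `categorical_crossentropy`.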
