Keras Practice Notes 10: Image Augmentation with ImageDataGenerator

from __future__ import print_function

import keras
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Dropout, Flatten
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator

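# Load CIFAR-10: 50,000 training and 10,000 test images, 32x32 RGB, 10 classes
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train /= 255
x_test /= 255

# One-hot encode the integer labels for categorical_crossentropy
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)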

# Two conv blocks (conv-conv-pool-dropout) followed by a dense classifier
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:], activation='relu'))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D())
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D())
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

# Print the layer table (output shapes and parameter counts)
model.summary()

# Keras 2.x optimizer API; newer tf.keras names the argument learning_rate instead of lr
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(lr=0.0001, decay=1e-6),
              metrics=['accuracy'])

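# Only random shifts (up to 10% of width/height) and random horizontal flips
# are enabled; all normalization options are off
datagen = ImageDataGenerator(
    featurewise_center=False,             # subtract the dataset mean
    samplewise_center=False,              # subtract each image's own mean
    featurewise_std_normalization=False,  # divide by the dataset std
    samplewise_std_normalization=False,   # divide by each image's own std
    zca_whitening=False,                  # ZCA whitening
    rotation_range=0,                     # no random rotation
    width_shift_range=0.1,                # horizontal shift, fraction of width
    height_shift_range=0.1,               # vertical shift, fraction of height
    horizontal_flip=True,                 # randomly mirror left-right
    vertical_flip=False)                  # no up-down flips

# Computes dataset statistics; only needed for the featurewise_* or
# zca_whitening options, which are all off here
datagen.fit(x_train)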

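# Train on batches of 32 freshly augmented images; 4 worker threads prepare batches.
# (In TensorFlow 2.1+ fit_generator is deprecated; model.fit accepts the generator directly.)
model.fit_generator(datagen.flow(x_train, y_train,
                                 batch_size=32),
                    epochs=100,
                    validation_data=(x_test, y_test),
                    workers=4)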

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 32)        896       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 30, 32)        9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 64)        18496     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        36928     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 6, 6, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2304)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               1180160   
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130      
=================================================================
Total params: 1,250,858
Trainable params: 1,250,858
Non-trainable params: 0
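The output shapes and parameter counts in this summary follow from simple arithmetic. A Conv2D layer has kernel_h × kernel_w × input_channels × filters weights plus one bias per filter. A Dense layer has inputs × units weights plus one bias per unit. The plain-Python sketch below needs no Keras and recomputes the table. Its helper names are hypothetical, and it assumes stride-1 3×3 convolutions and MaxPooling2D's default 2×2 pool.

```python
# Recompute output sizes and parameter counts for the CIFAR-10 model above.

def conv2d_params(in_ch, filters, k=3):
    # k*k*in_ch weights per filter, plus one bias per filter
    return k * k * in_ch * filters + filters

def dense_params(in_features, units):
    # one weight per input per unit, plus one bias per unit
    return in_features * units + units

def conv_out(size, k=3, padding="valid"):
    # stride 1: 'same' keeps the size, 'valid' shrinks it by k - 1
    return size if padding == "same" else size - k + 1

def pool_out(size, pool=2):
    # MaxPooling2D() default: 2x2 window, stride 2, 'valid' padding
    return size // pool

size = 32
size = conv_out(size, padding="same")   # conv2d_1        -> 32
size = conv_out(size)                   # conv2d_2        -> 30
size = pool_out(size)                   # max_pooling2d_1 -> 15
size = conv_out(size, padding="same")   # conv2d_3        -> 15
size = conv_out(size)                   # conv2d_4        -> 13
size = pool_out(size)                   # max_pooling2d_2 -> 6
flat = size * size * 64                 # flatten_1       -> 2304

params = [
    conv2d_params(3, 32),     # conv2d_1
    conv2d_params(32, 32),    # conv2d_2
    conv2d_params(32, 64),    # conv2d_3
    conv2d_params(64, 64),    # conv2d_4
    dense_params(flat, 512),  # dense_1
    dense_params(512, 10),    # dense_2
]
total = sum(params)  # pooling, dropout and flatten layers have no parameters
print(flat, params, total)  # → 2304 [896, 9248, 18496, 36928, 1180160, 5130] 1250858
```

Nearly 94% of the parameters (1,180,160 of 1,250,858) belong to dense_1, so most of the model's capacity sits in the first Dense layer rather than in the convolutions.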

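This example trains a convolutional neural network to classify images from the CIFAR-10 dataset. Unlike the earlier notes, it uses the ImageDataGenerator utility for image augmentation. The main difference is that there are now two fit calls. The first, `datagen.fit(x_train)`, does not generate any augmented images. It only computes dataset statistics (per-channel mean and standard deviation, and ZCA components), and these are needed only when featurewise_center, featurewise_std_normalization or zca_whitening is enabled. All three are False here, so this call has nothing to do and could be omitted. The augmentation itself happens on the fly during the second call, the actual training run, which uses `fit_generator`.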
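The sketch below makes the role of `datagen.fit` concrete. It computes the per-channel mean and standard deviation that featurewise_center and featurewise_std_normalization rely on, then applies them to one image. This is a simplified plain-Python illustration, not Keras's implementation. ZCA whitening is omitted, the function names are hypothetical, and the small epsilon used to avoid division by zero is an assumption (its exact value in Keras may differ by version).

```python
import math

def featurewise_stats(images):
    """Per-channel mean and population standard deviation over every pixel
    of every image. `images` is a list of H x W x C nested lists."""
    channels = len(images[0][0][0])
    pixels = [p for img in images for row in img for p in row]
    n = len(pixels)
    mean = [sum(p[c] for p in pixels) / n for c in range(channels)]
    std = [math.sqrt(sum((p[c] - mean[c]) ** 2 for p in pixels) / n)
           for c in range(channels)]
    return mean, std

def standardize(img, mean, std, eps=1e-6):
    """Subtract the dataset mean and divide by the dataset std, per channel.
    eps (an assumed value) guards against division by zero."""
    return [[[(p[c] - mean[c]) / (std[c] + eps) for c in range(len(p))]
             for p in row]
            for row in img]

# Two tiny 1x2 images with 2 channels each
img_a = [[[1.0, 10.0], [3.0, 10.0]]]
img_b = [[[1.0, 30.0], [3.0, 30.0]]]
mean, std = featurewise_stats([img_a, img_b])
print(mean, std)                      # → [2.0, 20.0] [1.0, 10.0]
print(standardize(img_a, mean, std))  # every value is close to -1 or +1
```

The statistics are computed once over the whole training set. Each image is only standardized later, when a batch is drawn, so fit never changes how many images there are.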

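For a description of each ImageDataGenerator parameter and its visual effect, see the excellent article "Too few images? Watch me transform: Keras Image Data Augmentation parameters explained" (图片数据集太少?看我七十二变,Keras Image Data Augmentation 各参数详解).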

Another feature of this training run is the argument passed to fit_generator:

datagen.flow(x_train, y_train,batch_size=32)

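flow() yields batches of 32 images indefinitely, and every image is randomly shifted and flipped anew each time it is drawn. The stored training set (x_train) does not grow. Across epochs, however, the model sees many differently transformed versions of the same small set of images, which acts like a much larger training set. Finally, fit_generator is given workers=4. Because use_multiprocessing defaults to False, this starts four worker threads that prepare batches in parallel. That speeds up data preparation, not the model's forward and backward passes.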
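The behaviour of `flow()` described above can be approximated with a plain-Python generator. This is a simplified stand-in, not Keras's implementation. It applies only the two augmentations this script enables:

- a random shift of up to 10% of each side, with vacated pixels copied from the nearest edge, as in ImageDataGenerator's default fill_mode='nearest';
- a random horizontal flip.

It shifts by whole pixels only, and the names `flow`, `random_shift` and `random_hflip` are hypothetical.

```python
import random

def random_shift(img, max_frac, rng):
    """Shift an H x W image by a random whole-pixel offset of at most max_frac
    of each side; vacated pixels copy the nearest edge pixel."""
    h, w = len(img), len(img[0])
    max_dy, max_dx = int(max_frac * h), int(max_frac * w)
    dy = rng.randint(-max_dy, max_dy)
    dx = rng.randint(-max_dx, max_dx)

    def clamp(v, size):
        return min(max(v, 0), size - 1)

    # Output pixel (r, c) reads source pixel (r - dy, c - dx), clamped to the border
    return [[img[clamp(r - dy, h)][clamp(c - dx, w)] for c in range(w)]
            for r in range(h)]

def random_hflip(img, rng):
    """Mirror the image left-right with probability 0.5."""
    return [row[::-1] for row in img] if rng.random() < 0.5 else img

def flow(images, labels, batch_size=32, shift=0.1, seed=None):
    """Yield (batch_x, batch_y) forever. Each image is re-augmented every time
    it is drawn, so successive passes see different variants of the same data."""
    rng = random.Random(seed)
    order = list(range(len(images)))
    while True:
        rng.shuffle(order)  # new order for every pass over the data
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            batch_x = [random_hflip(random_shift(images[i], shift, rng), rng)
                       for i in idx]
            batch_y = [labels[i] for i in idx]
            yield batch_x, batch_y

# Usage: five 10x10 "images" with integer pixels, in batches of 2
images = [[[r * 10 + c for c in range(10)] for r in range(10)] for _ in range(5)]
labels = [0, 1, 2, 3, 4]
gen = flow(images, labels, batch_size=2, seed=0)
for _ in range(4):
    batch_x, batch_y = next(gen)
    print(len(batch_x), batch_y)
```

The first three batches hold 2, 2 and 1 images. After that the generator starts a second, reshuffled pass instead of stopping. Because it never ends on its own, the training loop has to decide how many batches make up an epoch. In Keras's fit_generator that is the job of steps_per_epoch, or of the generator's length when it is a Sequence. For a 32×32 CIFAR image, a shift fraction of 0.1 means at most int(3.2) = 3 pixels in this sketch.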

Reposted from: https://my.oschina.net/xiaomaijiang/blog/1827510
