TensorFlow keras.preprocessing.image.ImageDataGenerator: Custom Image Datasets (Part 3)

This post explains how to define a dataset using Keras's ImageDataGenerator image generator.

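For an introduction to generator functions, see: https://blog.csdn.net/Forrest97/article/details/106317598
Advantage: for large image datasets, a generator saves memory because batches are loaded on demand instead of all at once. Within one epoch, no step should repeat data from another step (to be verified).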
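
The memory argument can be illustrated with a plain Python generator. This is only a minimal sketch and not how ImageDataGenerator is implemented: batch_generator and the file names are hypothetical, and it assumes the file list is shuffled once per epoch and then walked through without replacement.

```python
import random

def batch_generator(filenames, batch_size, seed=0):
    """Yield batches of file names forever, one shuffled pass per epoch.

    Hypothetical helper for illustration only: a real image generator
    would load and decode the files of each batch at this point.
    """
    rng = random.Random(seed)
    while True:
        order = list(filenames)
        rng.shuffle(order)  # new order for every epoch
        for start in range(0, len(order), batch_size):
            # Only this slice has to be held in memory at any time.
            yield order[start:start + batch_size]

files = [f"img_{i}.jpg" for i in range(10)]
gen = batch_generator(files, batch_size=4)

# One epoch is ceil(10 / 4) = 3 steps; count total and distinct files seen.
epoch = [next(gen) for _ in range(3)]
seen = [name for batch in epoch for name in batch]
print(len(seen), len(set(seen)))
```

Under these assumptions the two printed counts match, which is the "no repeated data within an epoch" property. Whether a given Keras version behaves the same way still has to be checked against its own iterator.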

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import os
import numpy as np
import matplotlib.pyplot as plt

Download a binary cats-vs-dogs image classification dataset to the local machine

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
print(PATH)

The top level of this directory is already split into two directories, train and validation.
Inside train there is one directory for each of the two classes, cats and dogs.
Create the training and validation image generator objects

train_image_generator = ImageDataGenerator(rescale=1./255) 
validation_image_generator = ImageDataGenerator(rescale=1./255) 
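
The rescale argument is a factor that every pixel value is multiplied by. A small sketch of the effect, using a hypothetical 2x2 grayscale image stored as nested lists rather than a real decoded file:

```python
# Hypothetical 2x2 grayscale image with 8-bit pixel values (0..255).
image = [[0, 64],
         [128, 255]]

rescale = 1. / 255  # same factor as passed to ImageDataGenerator above

# Multiply every pixel by the factor, mapping 0..255 onto roughly 0..1.
scaled = [[pixel * rescale for pixel in row] for row in image]

for row in scaled:
    print([round(value, 3) for value in row])
```

Keeping inputs in a small, fixed range such as 0 to 1 is a common choice for neural network training, which is presumably why the factor 1./255 is used here.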

Data augmentation can also be applied to the training set:

train_image_generator = ImageDataGenerator(rescale=1./255,
                                           rotation_range=30,
                                           width_shift_range=0.2,
                                           height_shift_range=0.2,
                                           shear_range=0.2,
                                           zoom_range=0.2,
                                           horizontal_flip=True)
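
To make the augmentation arguments concrete, the sketch below converts them into the intervals that random values are drawn from for a 150x150 image. It follows the Keras documentation's description of float arguments: a shift below 1 is a fraction of the image size, a float zoom_range z means the interval [1 - z, 1 + z], and rotation_range is given in degrees. The helper describe_augmentation is hypothetical and not part of Keras.

```python
def describe_augmentation(width, height, rotation_range, width_shift_range,
                          height_shift_range, zoom_range):
    """Hypothetical helper: turn ImageDataGenerator-style settings into
    the intervals that the random transforms are sampled from."""
    return {
        # Rotation angle in degrees, sampled symmetrically around 0.
        "rotation_deg": (-rotation_range, rotation_range),
        # Shift values below 1 are fractions of the image size.
        "width_shift_px": (-width_shift_range * width,
                           width_shift_range * width),
        "height_shift_px": (-height_shift_range * height,
                            height_shift_range * height),
        # A float zoom_range z is read as the interval [1 - z, 1 + z].
        "zoom_factor": (1 - zoom_range, 1 + zoom_range),
    }

ranges = describe_augmentation(150, 150, rotation_range=30,
                               width_shift_range=0.2, height_shift_range=0.2,
                               zoom_range=0.2)
for name, (low, high) in ranges.items():
    print(f"{name}: {low:g} to {high:g}")
```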

Configure the dataset generator parameters

batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')
                                                           
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
                                                              directory=validation_dir,
                                                              target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                              class_mode='binary')
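
flow_from_directory infers the labels from the subdirectory names and maps class names to indices in alphanumeric order, so cats becomes 0 and dogs becomes 1 here. The sketch below mimics that mapping on a temporary directory tree using only the standard library. infer_class_indices is a hypothetical helper; the real mapping can be read from train_data_gen.class_indices.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Hypothetical helper: map each subdirectory name to an integer
    label in sorted order, as flow_from_directory is documented to do."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Recreate the layout of the train directory with empty class folders.
    for class_name in ("dogs", "cats"):
        os.makedirs(os.path.join(root, class_name))
    class_indices = infer_class_indices(root)

print(class_indices)
```

With class_mode='binary', each batch then carries one scalar label per image, 0 or 1, taken from this mapping.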

Use next() on the iterator to inspect a batch of the training set

sample_training_images, _ = next(train_data_gen)
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
plotImages(sample_training_images[:5])

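Train the network (the original fit_generator call is deprecated since TensorFlow 2.1; model.fit accepts generators with the same arguments)

# model is assumed to be a compiled tf.keras model defined elsewhere.
# The sample counts are read from the iterators created above.
total_train = train_data_gen.samples
total_val = val_data_gen.samples

history = model.fit(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size
)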

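Because train_data_gen and val_data_gen are both generators, they keep producing batches of training and validation data for as long as training runs. steps_per_epoch tells Keras to draw total_train // batch_size batches from the training generator and then end the current epoch (not the whole training run); validation_steps limits the validation pass in the same way.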
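
As a worked example of the step arithmetic, assume the split contains 2000 training and 1000 validation images. These counts are an assumption for illustration; the real values can be read from train_data_gen.samples and val_data_gen.samples. Floor division drops the final partial batch, so the sketch also shows how many images fall outside the full batches and how many steps would cover every image.

```python
import math

batch_size = 128
total_train = 2000  # assumed training image count
total_val = 1000    # assumed validation image count

summary = {}
for name, total in (("train", total_train), ("validation", total_val)):
    full_steps = total // batch_size            # steps as computed in the article
    leftover = total - full_steps * batch_size  # images beyond the last full batch
    all_steps = math.ceil(total / batch_size)   # steps needed to cover every image
    summary[name] = (full_steps, leftover, all_steps)
    print(f"{name}: {full_steps} full steps, {leftover} images left over, "
          f"{all_steps} steps to cover all")
```

Using math.ceil instead of floor division is one way to include the last partial batch, at the cost of one smaller step per epoch.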
