This section mainly uses Keras's ImageDataGenerator to show how to define an image dataset with a generator.
For an introduction to generator functions, see: https://blog.csdn.net/Forrest97/article/details/106317598
Advantage: for large image datasets, a generator saves memory, and within a single epoch the steps should not yield duplicate samples (to be verified).
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np
import matplotlib.pyplot as plt
Download a cats-vs-dogs binary classification image dataset to the local machine.
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
print(PATH)
Under this directory, the top level is already split into train and validation subdirectories,
and under train there is one subdirectory for each of the two classes, cats and dogs.
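The training call later in this section uses total_train and total_val, so it helps to count the images in each subdirectory up front. The sketch below assumes the standard cats_and_dogs_filtered layout (train/cats, train/dogs, validation/cats, validation/dogs); the variable names are my own.
num_cats_tr = len(os.listdir(os.path.join(PATH, 'train', 'cats')))
num_dogs_tr = len(os.listdir(os.path.join(PATH, 'train', 'dogs')))
num_cats_val = len(os.listdir(os.path.join(PATH, 'validation', 'cats')))
num_dogs_val = len(os.listdir(os.path.join(PATH, 'validation', 'dogs')))

# Totals used later for steps_per_epoch and validation_steps.
total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val
print(total_train, total_val)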
Create the training and validation image generator objects.
train_image_generator = ImageDataGenerator(rescale=1./255)
validation_image_generator = ImageDataGenerator(rescale=1./255)
Data augmentation can also be applied to the training set:
train_image_generator = ImageDataGenerator(rescale=1./255,
                                           rotation_range=30,
                                           width_shift_range=0.2,
                                           height_shift_range=0.2,
                                           shear_range=0.2,
                                           zoom_range=0.2,
                                           horizontal_flip=True)
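To see what a single augmentation step does, ImageDataGenerator also exposes random_transform, which applies one random combination of the configured transforms to a single image array. The random array below is only a placeholder; 150x150 matches the target size used later.
# Apply one random rotation/shift/shear/zoom/flip to a single image array.
# The random image here is just a stand-in for a real photo.
dummy_img = np.random.rand(150, 150, 3)
augmented = train_image_generator.random_transform(dummy_img)
print(augmented.shape)  # same shape as the input, contents randomly transformed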
Configure the dataset generator parameters.
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
directory=validation_dir,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
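With class_mode='binary', flow_from_directory infers the labels from the subdirectory names; the mapping can be checked via the class_indices attribute:
# Check which subdirectory maps to label 0 and which to label 1.
print(train_data_gen.class_indices)  # e.g. {'cats': 0, 'dogs': 1}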
Use the iterator's next() to inspect a batch of training data.
sample_training_images, _ = next(train_data_gen)
# Plot the first five images in a single row to sanity-check the pipeline.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20, 20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
plotImages(sample_training_images[:5])
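Because train_image_generator was configured with augmentation, the same plotImages helper can be reused to inspect augmented samples; the sketch below simply draws another batch from the generator.
# Each call to next() yields a freshly augmented batch, so these images
# already include random rotations, shifts, shears, zooms and flips.
augmented_images, _ = next(train_data_gen)
plotImages(augmented_images[:5])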
Train the network with the fit_generator method.
history = model.fit_generator(
train_data_gen,
steps_per_epoch=total_train // batch_size,
epochs=epochs,
validation_data=val_data_gen,
validation_steps=total_val // batch_size
)
Since train_data_gen and val_data_gen are both generators, they keep producing batches of training and validation data on the fly during training. steps_per_epoch means that once total_train // batch_size batches have been drawn from the generator, the current epoch ends and no more batches are requested; for example, with 2000 training images and batch_size=128, that is 2000 // 128 = 15 steps per epoch.
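The training code above assumes a compiled model already exists; the small CNN below is only an assumed sketch, not the model from the source. Note also that in TensorFlow 2.x, Model.fit accepts generators directly and fit_generator is deprecated, so the same training loop can be written with model.fit.
# A minimal assumed CNN for binary classification (sketch, not the original model).
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation='relu',
                           input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# In TF 2.x, model.fit handles generators directly, replacing fit_generator.
history = model.fit(train_data_gen,
                    steps_per_epoch=total_train // batch_size,
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=total_val // batch_size)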