keras.utils.Sequence()
用于拟合数据序列的基对象,例如一个数据集。
每一个 Sequence 必须实现 getitem 和 len 方法。 如果你想在迭代之间修改你的数据集,你可以实现 on_epoch_end。 getitem 方法应该返回一个完整的批次。
Sequence 是进行多进程处理的更安全的方法。这种结构保证网络在每个时期每个样本只训练一次,这与生成器不同。
以上是keras官方手册的话,我的理解是Sequence应该是一个与数据迭代有关的类。
import numpy as np
import pandas as pd
import sklearn
import albumentations
from albumentations import (Blur,Flip,ShiftScaleRotate,GridDistortion,ElasticTransform,
HueSaturationValue,Transpose,RandomBrightnessContrast,CLAHE,
CoarseDropout,Normalize,ToFloat,OneOf,Compose)
import keras
import keras.backend as K
class MyGenerator(keras.utils.Sequence):
def __init__(self, image_filenames, labels, root_directory='',
batch_size=128, mix=False,
shuffle=True, augment=True):
self.image_filenames = image_filenames
self.labels = labels
self.root_directory = root_directory
self.batch_size = batch_size
self.is_mix = mix
self.is_augment = augment
self.shuffle = shuffle
if self.shuffle:
self.on_epoch_end()
if self.is_augment:
self.generator = Compose([Blur(),Flip(),Transpose(),ShiftScaleRotate(),
RandomBrightnessContrast(),HueSaturationValue(),
CLAHE(),GridDistortion(),ElasticTransform(),CoarseDropout(),
ToFloat(max_value=255.0,p=1.0)],p=1.0)
else:
self.generator = Compose([ToFloat(max_value=255.0,p=1.0)],p=1.0)
def __len__(self):
return int(np.ceil(len(self.image_filenames)/self.batch_size))
def on_epoch_end(self):
if self.shuffle:
self.image_filenames, self.labels = sklearn.utils.shuffle(self.image_filenames,self.labels)
def mix_up(self,x,y):
original_index = np.arange(x.shape[0])
new_index = np.arange(x.shape[0])
np.random.shuffle(new_index)
beta = np.random.beta(0.2, 0.4)
mix_x = beta * x[original_index] + (1 - beta) * x[new_index]
mix_y = beta * y[original_index] + (1 - beta) * y[new_index]
return mix_x,mix_y
def __getitem__(self,index):
batch_x = self.image_filenames[index*self.batch_size:(index+1)*self.batch_size]
batch_y = self.labels[index*self.batch_size:(index+1)*self.batch_size]
new_images = []
new_labels = []
for image_name,label in zip(batch_x,batch_y):
image = cv2.imread(os.path.join(self.root_directory,image_name))
image = cv2.resize(image,(300,300))
img = self.generator(image=image)['image']
new_images.append(img)
new_labels.append(label)
new_images = np.array(new_images)
new_labels = np.array(new_labels)
if self.is_mix:
new_images, new_labels = self.mix_up(new_images, new_labels)
return new_images,new_labels
def init(self, image_filenames, labels, root_directory=’’, batch_size=128, mix=False, shuffle=True, augment=True)
该函数初始化代码所需的参数,需要注意的是我在这里使用albumentations库根据是否使用数据增强参数定义了数据增强器。如果是则该数据增强器将会随机对输入的图像数据进行模糊,翻转,网格失真,弹性变换,色调、饱和度、值变化,亮度、对比度变化,对比度受限自适应直方图均衡化,在图像上生成矩形区域,最后进行归一化除以255.0;如果否,则只进行归一化。
image_filenames文件名(需要包含文件后缀)
labels标签,
root_directory=’'根目录,
batch_size=128批次大小,
mix=False是否混合,
shuffle=True是否打乱顺序,
augment=True是否进行数据增强
def on_epoch_end(self)
该函数会在一轮数据训练完后被调用,我们在这里是实现打乱数据顺序
def len(self)
该函数返回一轮数据训练完需要的迭代次数
def mix_up(self,x,y)
该函数是实现实现混合的,既随机将数据进行组合,我在这里是将图像数据按比例相加。
def getitem(self,index)
该函数会在每次数据迭代是被调用,index是在训练一轮数据时,其是第几次迭代,通过它我们可获得这次迭代所需的数据。
该数主要实现图像数据读取,数据增强,数据混合,并返回增强后的数据和标签。
batch_size = 50
#训练集数据增强器
train_generator = MyGenerator(train_x, train_y, '../input',
batch_size=batch_size)
#带混合的训练集数据增强器
train_mixup = MyGenerator(train_x, train_y, '../input',
batch_size=batch_size,mix=True)
#验证集数据增强器
valid_generator = MyGenerator(valid_x, valid_y, '../input',
batch_size=batch_size,augment=False,shuffle=False)