keras 使用Albumentations库自定义数据增强器

keras 使用Albumentations库自定义数据增强器

  • Sequence
  • 自定义数据增强器

Sequence

keras.utils.Sequence()
用于拟合数据序列的基对象,例如一个数据集。
每一个 Sequence 必须实现 getitemlen 方法。 如果你想在迭代之间修改你的数据集,你可以实现 on_epoch_end。 getitem 方法应该返回一个完整的批次。
Sequence 是进行多进程处理的更安全的方法。这种结构保证网络在每个时期每个样本只训练一次,这与生成器不同。

以上是keras官方手册的话,我的理解是Sequence应该是一个与数据迭代有关的类。

自定义数据增强器

import numpy as np
import pandas as pd
import sklearn
import albumentations
from albumentations import (Blur,Flip,ShiftScaleRotate,GridDistortion,ElasticTransform,
                            HueSaturationValue,Transpose,RandomBrightnessContrast,CLAHE,
                            CoarseDropout,Normalize,ToFloat,OneOf,Compose)
import keras
import keras.backend as K


class MyGenerator(keras.utils.Sequence):
    def __init__(self, image_filenames, labels, root_directory='',
                 batch_size=128, mix=False,
                 shuffle=True, augment=True):
        self.image_filenames = image_filenames
        self.labels = labels
        self.root_directory = root_directory
        self.batch_size = batch_size
        self.is_mix = mix
        self.is_augment = augment
        self.shuffle = shuffle 
        if self.shuffle:
            self.on_epoch_end()
        if self.is_augment:
            self.generator = Compose([Blur(),Flip(),Transpose(),ShiftScaleRotate(),
                                  RandomBrightnessContrast(),HueSaturationValue(),
                                 CLAHE(),GridDistortion(),ElasticTransform(),CoarseDropout(),
                                 ToFloat(max_value=255.0,p=1.0)],p=1.0)
        else:
            self.generator = Compose([ToFloat(max_value=255.0,p=1.0)],p=1.0)
    def __len__(self):
        return int(np.ceil(len(self.image_filenames)/self.batch_size))
    
    def on_epoch_end(self):
        if self.shuffle:
            self.image_filenames, self.labels = sklearn.utils.shuffle(self.image_filenames,self.labels)
    def mix_up(self,x,y):
        original_index = np.arange(x.shape[0])
        new_index = np.arange(x.shape[0])
        np.random.shuffle(new_index)
        beta = np.random.beta(0.2, 0.4)
        mix_x = beta * x[original_index] + (1 - beta) * x[new_index]
        mix_y = beta * y[original_index] + (1 - beta) * y[new_index]
        return mix_x,mix_y
        
    def __getitem__(self,index):
        batch_x = self.image_filenames[index*self.batch_size:(index+1)*self.batch_size]
        batch_y = self.labels[index*self.batch_size:(index+1)*self.batch_size]
        new_images = []
        new_labels = []
        for image_name,label in zip(batch_x,batch_y):
            image = cv2.imread(os.path.join(self.root_directory,image_name))
            image = cv2.resize(image,(300,300))
            img = self.generator(image=image)['image']
            new_images.append(img)
            new_labels.append(label)
        new_images = np.array(new_images)
        new_labels = np.array(new_labels)
        if self.is_mix:
            new_images, new_labels = self.mix_up(new_images, new_labels)
        return new_images,new_labels     

def init(self, image_filenames, labels, root_directory=’’, batch_size=128, mix=False, shuffle=True, augment=True)
该函数初始化代码所需的参数,需要注意的是我在这里使用albumentations库根据是否使用数据增强参数定义了数据增强器。如果是则该数据增强器将会随机对输入的图像数据进行模糊,翻转,网格失真,弹性变换,色调、饱和度、值变化,亮度、对比度变化,对比度受限自适应直方图均衡化,在图像上生成矩形区域,最后进行归一化除以255.0;如果否,则只进行归一化。
image_filenames文件名(需要包含文件后缀)
labels标签,
root_directory=’'根目录,
batch_size=128批次大小,
mix=False是否混合,
shuffle=True是否打乱顺序,
augment=True是否进行数据增强

def on_epoch_end(self)
该函数会在一轮数据训练完后被调用,我们在这里是实现打乱数据顺序

def len(self)
该函数返回一轮数据训练完需要的迭代次数

def mix_up(self,x,y)
该函数是实现实现混合的,既随机将数据进行组合,我在这里是将图像数据按比例相加。

def getitem(self,index)
该函数会在每次数据迭代是被调用,index是在训练一轮数据时,其是第几次迭代,通过它我们可获得这次迭代所需的数据。
该数主要实现图像数据读取,数据增强,数据混合,并返回增强后的数据和标签。

batch_size = 50
#训练集数据增强器
train_generator = MyGenerator(train_x, train_y, '../input',
                             batch_size=batch_size)
#带混合的训练集数据增强器
train_mixup = MyGenerator(train_x, train_y, '../input', 
                          batch_size=batch_size,mix=True)
#验证集数据增强器
valid_generator = MyGenerator(valid_x, valid_y, '../input',
                              batch_size=batch_size,augment=False,shuffle=False)

你可能感兴趣的:(keras 使用Albumentations库自定义数据增强器)