使用 keras.utils.Sequence 生成数据

keras.utils.Sequence

昨天写了一篇文章，介绍用 multiprocessing.Pool 多进程加载数据、再用 yield 生成数据送入 model.fit_generator 训练（见上一篇：multiprocessing.Pool）。今天试着用 keras.utils.Sequence 基类构建一个数据生成器，其数据加载速度与 multiprocessing.Pool 方案差不多。
Sequence 是进行多进程处理的更安全的方法：与生成器不同，这种结构保证网络在每个 epoch 中对每个样本只训练一次。每一个 Sequence 都必须实现 `__getitem__` 和 `__len__` 方法。
如果想在两个 epoch 之间修改数据集（例如打乱样本顺序），可以实现 `on_epoch_end` 方法。
`__getitem__` 方法应该返回一个完整的批次。
使用 keras.utils.Sequence 处理数据时，可以在 model.fit_generator 中设置 use_multiprocessing=True, workers=Numprocess，而上一篇文章的数据加载方法不能使用这两个参数。
下面是基于图像分类任务、用 keras.utils.Sequence 加载数据并 feed 到 model.fit_generator 中训练的代码，可以与上一篇文章中给出的代码进行速度对比。

import keras
import cv2
import numpy as np


class Dataloader(keras.utils.Sequence):
    """Batch generator for image classification, for use with Keras training loops.

    Reads a text file in which every non-blank line is ``imgPath,label``.
    ``__getitem__`` returns ``(images, labels)`` for one batch: images are RGB,
    randomly augmented, scaled by ``(pixel - 127) / 128`` and resized with
    ``cv2.resize(..., img_shape)``; labels are one-hot vectors of length
    ``labelNum``.

    NOTE(review): augmentation draws from the global ``np.random`` state; with
    forked multiprocessing workers each worker may start from the same state
    and produce identical augmentations — consider reseeding per worker.
    """

    def __init__(self, filepath, batchsize, img_shape=(224, 224),
                 label_dict=None, shuffle=True, **kwargs):
        """Load the list of samples.

        Args:
            filepath: text file with one ``imgPath,label`` sample per line;
                blank lines are ignored.
            batchsize: samples per batch (the last batch may be smaller).
            img_shape: ``dsize`` passed to ``cv2.resize``, i.e. (width, height).
            label_dict: mapping from label string to class index. If None, it
                is built from the sorted set of labels found in ``filepath``;
                pass it explicitly to keep the mapping identical across
                training and validation files.
            shuffle: whether ``on_epoch_end`` reshuffles the sample order.
            **kwargs: forwarded to ``keras.utils.Sequence.__init__`` (Keras 3's
                PyDataset accepts e.g. ``workers``/``use_multiprocessing``).
        """
        super().__init__(**kwargs)
        # `with` closes the handle; blank lines (e.g. a trailing newline) would
        # otherwise fail to unpack into (imgPath, label) at load time.
        with open(filepath) as f:
            self.lines = [line for line in f if line.strip()]
        self.batchsize = batchsize
        self.img_shape = img_shape
        self.shuffle = shuffle
        if label_dict is None:
            labels = sorted({line.strip().rsplit(",", 1)[-1] for line in self.lines})
            label_dict = {label: index for index, label in enumerate(labels)}
        self.labelDict = label_dict
        self.labelNum = len(self.labelDict)

    def __len__(self):
        """Return the number of batches per epoch (ceil(samples / batchsize))."""
        return (len(self.lines) + self.batchsize - 1) // self.batchsize

    def __getitem__(self, idx):
        """Build batch ``idx`` and return it as ``(images, labels)`` NumPy arrays."""
        start = idx * self.batchsize
        images = []
        labels = []
        for content in self.lines[start:start + self.batchsize]:
            image, label = self.get_random_data(content, self.img_shape,
                                                self.labelDict, self.labelNum)
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.lines)

    @staticmethod
    def _add_noise(image, percentage):
        """Return a copy of ``image`` with ``percentage * H * W`` random channel values overwritten."""
        noise_image = image.copy()
        im_h, im_w = image.shape[:2]
        noise_num = int(percentage * im_w * im_h)
        ys = np.random.randint(0, im_h, size=noise_num)
        xs = np.random.randint(0, im_w, size=noise_num)
        channels = np.random.randint(3, size=noise_num)
        # One vectorized assignment instead of a Python loop per pixel. Values
        # are standard-normal draws cast to the image dtype on assignment.
        # NOTE(review): for a uint8 image, draws in (-1, 1) become 0 and the
        # cast of negative values is platform-dependent — confirm intended noise.
        noise_image[ys, xs, channels] = np.random.randn(noise_num)
        return noise_image

    @staticmethod
    def _rotate(image):
        """Rotate ``image`` about its centre by a random angle in [-10, 10) degrees."""
        h, w = image.shape[:2]
        angle = (np.random.random() - 0.5) * 20
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        return cv2.warpAffine(image, matrix, (w, h))

    @staticmethod
    def _crop(image):
        """Trim a random 30-49 px margin from the top and from the left of ``image``."""
        img_h, img_w = image.shape[:2]
        # Clamp so that very small images keep at least one row/column instead
        # of becoming empty (which would make cv2.resize fail later).
        top = min(np.random.randint(30, 50), max(img_h - 1, 0))
        left = min(np.random.randint(30, 50), max(img_w - 1, 0))
        return image[top:, left:, :]

    def get_random_data(self, content, img_shape, labelDict, labelNum):
        """Load one sample, apply random augmentation and build its one-hot label.

        Args:
            content: one ``imgPath,label`` line from the sample file.
            img_shape: ``dsize`` for ``cv2.resize``, i.e. (width, height).
            labelDict: mapping from label string to class index.
            labelNum: length of the one-hot label vector.

        Returns:
            ``(image, label)``: the RGB image scaled by ``(pixel - 127) / 128``
            and resized, and a float one-hot vector with a 1 at
            ``labelDict[label]``.

        Raises:
            IOError: if the image file cannot be read.
            KeyError: if the label is not present in ``labelDict``.
        """
        num_augmentations = 4  # 0: flip, 1: crop, 2: noise, 3: rotate

        # rsplit keeps image paths that themselves contain commas intact.
        imgPath, label = content.strip().rsplit(",", 1)
        image = cv2.imread(imgPath)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError("Could not read image: {}".format(imgPath))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply 0..num_augmentations-1 distinct augmentations in random order.
        aug_num = np.random.randint(low=0, high=num_augmentations)
        for aug in np.random.permutation(num_augmentations)[:aug_num]:
            if aug == 0:
                # fliplr returns a negative-stride view; copy it so later
                # OpenCV calls receive a contiguous array.
                image = np.ascontiguousarray(np.fliplr(image))
            elif aug == 1:
                image = self._crop(image)
            elif aug == 2:
                image = self._add_noise(image, 0.25)
            elif aug == 3:
                image = self._rotate(image)

        image = (image - 127) * 0.0078125  # 0.0078125 == 1/128
        image = cv2.resize(image, img_shape)
        one_hot = np.zeros(labelNum)
        one_hot[labelDict[label]] = 1  # [1, 0, 0, ...]
        return image, one_hot

你可能感兴趣的:(keras)