用Keras写出像PyTorch一样的DataLoader方法

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

     数据导入、网络构建和模型训练永远是深度学习代码的主要模块。笔者此前曾写过PyTorch数据导入的pipeline标准结构总结PyTorch数据Pipeline标准化代码模板,本文参考PyTorch的DataLoader,给Keras也总结一套自定义的DataLoader框架。

用Keras写出像PyTorch一样的DataLoader方法_第1张图片

Keras常规用法

     按照正常人使用Keras的方法,大概就像如下代码一样:

import numpy as np
from keras.models import Sequential
# 导入全部数据
X, y = np.load('some_training_set_with_labels.npy')
# Design model
model = Sequential()
[...] # 网络结构
model.compile()
# 模型训练
model.fit(x=X, y=y)

     虽然一次性导入训练数据一定程度上能够提高训练速度,但随着数据量增多,这种将数据一次性读入内存的方法很容易造成显存溢出的问题。所以,在开启一个深度学习项目时,一个较为明智的做法就是分批次读取训练数据。

数据存放方式

     常规情况下,我们的训练数据要么是按照分类和阶段有组织的存放在硬盘目录下(多见于比赛和标准数据集),要么以csv格式将数据路径和对应标签给出(多见于深度学习项目情形)。

用Keras写出像PyTorch一样的DataLoader方法_第2张图片

数据按照类别和使用阶段存放(kaggle猫狗分类数据集)

用Keras写出像PyTorch一样的DataLoader方法_第3张图片

数据按照csv文件形式给出(花朵分类数据集)

ImageDataGenerator

     Keras早就考虑到了按批次导入数据的需求,所以ImageDataGenerator模块提供了按批次导入的数据生成器方法,包括数据增强和分批训练等方法。如下所示,分别对训练集和验证集调用ImageDataGenerator函数,然后从目录下按批次导入。

from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 数据增强
train_datagen = ImageDataGenerator(
      rescale=1./255,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True)


test_datagen = ImageDataGenerator(rescale=1./255)
# 从目录下按批次读取
train_generator = train_datagen.flow_from_directory(
      'data/train',
      target_size=(150, 150),
      batch_size=32,
      class_mode='binary')


validation_generator = test_datagen.flow_from_directory(
      'data/validation',
      target_size=(150, 150),
      batch_size=32,
      class_mode='binary')

最后对模型调用fit_generator方法进行训练:

model.fit_generator(
      train_generator,
      steps_per_epoch=2000,
      epochs=50,
      validation_data=validation_generator,
      validation_steps=800)

     以上Keras提供的数据生成器的方法读入数据虽然好,但还不够灵活,实际深度学习项目会碰到各种不同的数据存放情况,根据实际情况来自定义一套类似于PyTorch的DataLoader非常有必要。

Keras Sequence

     Keras Sequence方法用于拟合一个数据序列,每一个Sequence必须提供__getitem__和__len__方法,这跟Torch的Dataset模块类似。Sequence是进行多进程处理的更安全的方法,这种结构保证网络在每个时期每个样本只训练一次,这与生成器不同。使用示例如下:

from skimage.io import imread
from skimage.transform import resize 
import numpy as np 
from keras.utils import Sequence


# x_set是图像的路径列表 
# y_set是对应的类别
class CIFAR10Sequence(Sequence): 
    def __init__(self, x_set, y_set, batch_size): 
        self.x, self.y = x_set, y_set 
        self.batch_size = batch_size 


    def __len__(self): 
        return int(np.ceil(len(self.x) / float(self.batch_size))) 


    def __getitem__(self, idx): 
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size] 
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size] 
        return np.array([ resize(imread(file_name), (200, 200)) for file_name in batch_x]), np.array(batch_y)

Torch风格的Keras DataLoader

     现在我们针对一个13分类的多标签图像分类问题来自定义Torch风格的DataLoader。数据以csv的形式存放图片路径和对应标签,具体如下:

用Keras写出像PyTorch一样的DataLoader方法_第4张图片

     可以看到,每张图像都有至少一个、至多三个的动物标签。所以标签在处理的时候需要进行转化。首先定义继承Sequence的DataGenerator类和一些初始化方法。

class DataGenerator(Sequence):
    """
    基于Sequence的自定义Keras数据生成器
    """
    def __init__(self, df, list_IDs,
                 to_fit=True, batch_size=8, dim=(256, 472),
                 n_channels=3, n_classes=13, shuffle=True):
        """ 初始化方法
        :param df: 存放数据路径和标签的数据框
        :param list_IDs: 数据索引列表
        :param to_fit: 设定是否返回标签y
        :param batch_size: batch size 
        :param dim: 图像大小
        :param n_channels: 图像通道
        :param n_classes: 标签类别
        :param shuffle: 每一个epoch后是否打乱数据
        """
        self.df = df
        self.list_IDs = list_IDs
        self.to_fit = to_fit
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

     然后定义on_epoch_end方法来在每个epoch之后shuffle数据,以及底层数据读取和标签编码方法。

def on_epoch_end(self):
    """每个epoch之后更新索引
    """
    self.indexes = np.arange(len(self.list_IDs))
    if self.shuffle == True:
        np.random.shuffle(self.indexes)

     图像读取方法:

def _load_image(self, image_path):
    """cv2读取图像
    """
    # img = cv2.imread(image_path)
    img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    w, h, _ = img.shape
    if w>h:
        img = np.rot90(img)
    img = cv2.resize(img, (472, 256))
    return img

     标签编码转换方法:

def _labels_encode(self, s, keys):
    """标签one-hot编码转换
    """
    cs = s.split('_')
    y = np.zeros(13)
    for i in range(len(cs)):
        for j in range(len(keys)):
            for c in cs:
                if c == keys[j]:
                    y[j] = 1
    return y

     然后定义每个批次生成图片和标签的方法:

def _generate_X(self, list_IDs_temp):
    """生成每一批次的图像
    :param list_IDs_temp: 批次数据索引列表
    :return: 一个批次的图像
    """
    # 初始化
    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    # 生成数据
    for i, ID in enumerate(list_IDs_temp):
        # 存储一个批次
        X[i,] = self._load_image(self.df.iloc[ID].images)
    return X


def _generate_y(self, list_IDs_temp):
    """生成每一批次的标签
    :param list_IDs_temp: 批次数据索引列表
    :return: 一个批次的标签
    """
    y = np.empty((self.batch_size, self.n_classes), dtype=int)
    # Generate data
    for i, ID in enumerate(list_IDs_temp):
        # Store sample
        y[i,] = self._labels_encode(self.df.iloc[ID].labels, config.LABELS)
    return y

     底层读取和生成方法定义完成后,即可定义__getitem__和__len__方法:

def __getitem__(self, index):
    """生成每一批次训练数据
    :param index: 批次索引
    :return: 训练图像和标签
    """
    # 生成批次索引
    indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
    # 索引列表
    list_IDs_temp = [self.list_IDs[k] for k in indexes]
    # 生成数据
    X = self._generate_X(list_IDs_temp)
    if self.to_fit:
        y = self._generate_y(list_IDs_temp)
        return X, y
    else:
        return X
        
def __len__(self):
    """每个epoch下的批次数量
    """
    return int(np.floor(len(self.list_IDs) / self.batch_size))

    完整的Keras DataLoader代码如下:

class DataGenerator(Sequence):
    """
    基于Sequence的自定义Keras数据生成器
    """
    def __init__(self, df, list_IDs,
                 to_fit=True, batch_size=8, dim=(256, 472),
                 n_channels=3, n_classes=13, shuffle=True):
        """ 初始化方法
        :param df: 存放数据路径和标签的数据框
        :param list_IDs: 数据索引列表
        :param to_fit: 设定是否返回标签y
        :param batch_size: batch size 
        :param dim: 图像大小
        :param n_channels: 图像通道
        :param n_classes: 标签类别
        :param shuffle: 每一个epoch后是否打乱数据
        """
        self.df = df
        self.list_IDs = list_IDs
        self.to_fit = to_fit
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()
        
   def __getitem__(self, index):
        """生成每一批次训练数据
        :param index: 批次索引
        :return: 训练图像和标签
        """
        # 生成批次索引
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # 索引列表
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # 生成数据
        X = self._generate_X(list_IDs_temp)
        if self.to_fit:
            y = self._generate_y(list_IDs_temp)
            return X, y
        else:
            return X
        
    def __len__(self):
        """每个epoch下的批次数量
        """
        return int(np.floor(len(self.list_IDs) / self.batch_size))
        
        def _generate_X(self, list_IDs_temp):
        """生成每一批次的图像
        :param list_IDs_temp: 批次数据索引列表
        :return: 一个批次的图像
        """
        # 初始化
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        # 生成数据
        for i, ID in enumerate(list_IDs_temp):
            # 存储一个批次
            X[i,] = self._load_image(self.df.iloc[ID].images)
        return X


    def _generate_y(self, list_IDs_temp):
        """生成每一批次的标签
        :param list_IDs_temp: 批次数据索引列表
        :return: 一个批次的标签
        """
        y = np.empty((self.batch_size, self.n_classes), dtype=int)
        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            y[i,] = self._labels_encode(self.df.iloc[ID].labels, config.LABELS)
        return y
        
     def on_epoch_end(self):
        """每个epoch之后更新索引
        """
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
            
     def _load_image(self, image_path):
        """cv2读取图像
        """
        # img = cv2.imread(image_path)
        img = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR)
        w, h, _ = img.shape
        if w>h:
            img = np.rot90(img)
        img = cv2.resize(img, (472, 256))
        return img
        
     def _labels_encode(self, s, keys):
        """标签one-hot编码转换
        """
        cs = s.split('_')
        y = np.zeros(13)
        for i in range(len(cs)):
            for j in range(len(keys)):
                for c in cs:
                    if c == keys[j]:
                        y[j] = 1
        return y

     使用效果如下(打印每一批次输入输出的shape):

用Keras写出像PyTorch一样的DataLoader方法_第5张图片

     实际训练时,我们可以大致编写如下训练代码框架:

import numpy as np
from keras.models import Sequential
import DataGenerator
# Parameters
params = {'batch_size': 64,
          'n_classes': 6,
          'n_channels': 1,
          'shuffle': True}
# Generators
training_generator = DataGenerator(train_df, train_idx, **params)
validation_generator = DataGenerator(val_df, val_idx, **params)


# Design model
model = Sequential()
[...] # Architecture
model.compile()


# Train model on dataset
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=4)


     以上就是本文主要内容。本文提供的Keras DataLoader方法仅供参考使用,自定义Keras DataLoader还应根据具体数据组织形式来灵活决定。

  参考资料:

https://towardsdatascience.com/keras-data-generators-and-how-to-use-them-b69129ed779c

小白团队出品:零基础精通语义分割↓

用Keras写出像PyTorch一样的DataLoader方法_第6张图片

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

用Keras写出像PyTorch一样的DataLoader方法_第7张图片

用Keras写出像PyTorch一样的DataLoader方法_第8张图片

你可能感兴趣的:(python,人工智能,深度学习,机器学习,opencv)