tensorflow.keras.utils.Sequence的使用

tensorflow.keras.utils.Sequence的使用(控制模型从文件读入batch_size的数据

在使用keras的时候,一般使用model.fit()来传入训练数据,fit()接受多种类型的数据:
1.数组类型(如numpy等)。注意,tensorflow2以后的版本在接受h5py类型数据时,容易出错,原因我也不是特别懂
2.dataset类型
3.python generator,但是限制比较多,一般要在编写python generator的平 台下运行模型
4.tensorflow.keras.utils.Sequence,和python generator差不多,但是限制较少,可迁移性更好

第4种类型是本文要讲的重点类型

官方例子

    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math

    # Here, `x_set` is list of path to the images
    # and `y_set` are the associated classes.

    class CIFAR10Sequence(Sequence):

        def __init__(self, x_set, y_set, batch_size):
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            batch_x = self.x[idx * self.batch_size:(idx + 1) *
            self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) *
            self.batch_size]

            return np.array([
                resize(imread(file_name), (200, 200))
                   for file_name in batch_x]), np.array(batch_y)

附上连接:官方例子

下面是我的理解

init():初始化类。
len():返回batch_size的个数,也就是完整跑一遍数据要运行运行模型多少次。
getitem():返回一个batch_size的数据(data,label)
on_epoch_end():这个函数例子中没有用到,但是官网有给,就是在每个 epoch跑完之后,你要做什么可以通过这个函数实现

这是以上函数的作用,虽然官方给的例子是像上面那样的。但是我们却不一定要写和它一模一样的格式,只要每个函数返回的东西和上面例子一样就行(比如:getitem()返回的是一个batch_size的数据,只要你在这个函数返回的是一个batch_size的数据,那么函数里面怎么运行的都可以)。

下面是我自己定义的一个Sequence类,用于从多个h5py文件中读取点云数据

import tensorflow as tf
import math
import h5py
from tensorflow.keras.utils import Sequence


class h5py_file_sequence(Sequence):
    def __init__(self, file_list, batch_size, sampling_num):
        self.open_list = [h5py.File(file) for file in file_list]
        self.batch_size = batch_size
        self.sampling_num = sampling_num

    def __len__(self):
        per_file_batch_num = [math.ceil(f['data'].shape[0] / self.batch_size) for f in self.open_list]
        return tf.reduce_sum(per_file_batch_num)

    def __getitem__(self, idx):
        n = math.ceil(self.open_list[0]['data'].shape[0] / self.batch_size)  # 第一个文件有多少个batch_size大小
        # 这里有一个假设:前面的几个文件含有的数据个数相同,只有最后一个小于等于前面的
        batch_x = self.open_list[idx // n]['data'][(idx % n) * self.batch_size:(idx % n + 1) * self.batch_size, 0:self.sampling_num, :]
        batch_y = self.open_list[idx // n]['label'][(idx % n) * self.batch_size:(idx % n + 1) * self.batch_size]
        return tf.convert_to_tensor(batch_x), tf.convert_to_tensor(batch_y)

    def on_epoch_end(self):
        for f in self.open_list:
            f.close()

可以看到在__getitem__()函数中的返回值,被我强制转换为tensor类型了,这是因为h5py类型的数组数据在fit()进去的时候会出错。关于这一点我在开头已经提到过了。至于为什么对h5py文件不太支持我也不知道。并且tensorflow2的文档里面也没有对h5py类型的数据操作作介绍。

欢迎留言交流

你可能感兴趣的:(深度学习,tensorflow,神经网络,机器学习,数据挖掘)