keras生成器(data_generator)批量加载训练数据

1 背景

1、在做深度学习时候,数据量往往非常大,如果一次性加载到内存中,往往会出现OOM问题,为了解决这个问题,我们就不用model.fit()来训练,而用model.fit_generator()来训练。这样来数据就用生成器(lazy calculate)的方式加载数据,减少内存压力。

2、keras的官方文档里并没有实现batch_size的功能

3、同时,看了几篇中文文档,写的data_generator函数都存在batch size > 数据量时候,或者读到文件末尾但是没有达到batch_size的时候,导致生成的数据不正确。

2 如何实现

假设有如下数据:

[
    {
        "sentence": "奖励就是亲",
        "judge": "positive"
    },
    {
        "sentence": "不许这样",
        "judge": "negative"
    },

    {
        "sentence": "他女朋友旁边所以不方便跟说话",
        "judge": "negative"
    },
    {
        "sentence": "真的那好心疼",
        "judge": "negative"
    }
  ]

data_generator 函数如下

 def data_generator(self, file_name, batch_size):
        """
        :return:
        """
        # !!一定要将这几行写到while 循环外面
        train_data_indi = [] 
        train_data_seg = []
        train_data_label = []
        cnt = 0
        while 1:
            label_tags = ['negative', 'positive']
            with open(file_name) as f:
                json_items = ijson.items(f, 'item')
                for json_data in json_items:
                    cnt += 1
                    train_data_label.append([label_tags.index(json_data['judge'])])
                    if cnt == batch_size:
                        print(train_data_label)
                        train_data_label = pad_sequences(train_data_label)
                        train_data_label = to_categorical(train_data_label, 2)
                        yield (train_data_label) #返回数据
                        train_data_indi = []
                        train_data_seg = []
                        train_data_label = []
                        cnt = 0

实验数据

3 实验验证

注:上面代码只截取了一个函数,请自行修改下哈

if __name__ == "__main__":
    test = DataPreprocessor()
    cnt = 0
    for i in test.data_generator('./data/test.json', 5):
    #for i in test_gen('./data/test.json'):
        cnt += 1
        if cnt > 5:
            break

结果如下

Using TensorFlow backend.
[[1], [0], [0], [0], [1]]
[[0], [0], [0], [1], [0]]
[[0], [0], [1], [0], [0]]
[[0], [1], [0], [0], [0]]
[[1], [0], [0], [0], [1]]
[[0], [0], [0], [1], [0]]

4 总结

1、因为fit_generator 调用data_generator时候,需要循环读取文件(可能有多个epoch), 所以需要写一个死循环(while 1: )来轮训读取数据

2、当迭代器迭代到文件末尾数据,但是还到不到batch_size( if cnt == batch_size:),就会从新进入第二轮的while 循环,第二次读取文件,将文件开头的数据放到一个batch数据中,直到满足batch_size 返回数据

3、如果有不正确的地方欢迎指正啊,不想误导他人,也想自己将本质搞清楚。

你可能感兴趣的:(深度学习)