keras 分批训练 详解2 - keras 进阶教程

keras 分批训练2

今天讲的是如何使用keras进行分批训练(也叫增量训练增量学习在线训练批量训练)的第二种方法,上一种方法在这里:
https://blog.csdn.net/weixin_42744102/article/details/87272950

上一次讲的是fit_generator的方法,那个方法搞不清每层的名字就很容易报错、计算错了step和batch也很麻烦,需要自己写的生成器和该方法密切耦合,其实不太友好;那么今天我们讲讲另一种方法:train_on_batch

使用方法:

model.train_on_batch(x, y)

使用方法很简单,只需要传入一个batch的data和target就可以,有两个可选参数可以调整权重,一般不用填。
这个方法的好处就是简单易用而且直观,不需要处理fit_genrator的各种step、layer name的问题,不那么强耦合;直观之处就在于,整个方法的作用就是送一个batch的数据进去训练,调用一次就是用batch训练一次,很灵活,个人非常推荐使用。

当然,这个方法需要结合自己实现的一些别的东西,才能完成训练,你需要手动循环epoch次,然后每个循环里面嵌套一个循环,这个便利整个数据集,产生一个一个的batch,并在产生了这些batch之后,调用train_on_batch方法进行训练。这些步骤在fit_genrator中是实现了的,但是它的强耦合导致对开发不太友好,因此还不如自己实现

好了,贴上我一个工程里面的样例代码:

    for epoch in range(EPOCH):

        print('epoch', epoch)

        print(int(data_amo*(1.-VALIDATE_SPLIT)))

        for b_idx in range(0, int(data_amo*(1.-VALIDATE_SPLIT)), BATCH):

            with open('random_data') as f:

                data_gram_sentence = f.readlines()[b_idx:b_idx+BATCH]

            with open('random_target') as f:

                data_target = list(f.readlines())[b_idx:b_idx+BATCH]

            train_x = []

            train_y = []

            for sentence_gram_index in range(len(data_gram_sentence)):

                sentence_gram = data_gram_sentence[sentence_gram_index]

                grams = sentence_gram[:-1].split(' ')

                valid = 0

                sentence_vector = np.zeros(NUM_FEATURES)

                for gram in grams:

                    if gram in model:

                        valid += 1

                        sentence_vector += model[gram]

                if valid != 0:

                    sentence_vector = sentence_vector / valid

                train_x.append(sentence_vector)

                target_single = np.zeros(len(data_li))

                # print('*'*10)

                # print(one_hot_dict[data_target[sentence_gram_index][:-1]])

                target_single[int(one_hot_dict[data_target[sentence_gram_index][:-1]])] = 1.

                # target_single[one_hot_dict[[data_target[sentence_gram_index][:-1]]]] = 1.

                # print(target_single)

                train_y.append(target_single)

            train_x = np.array(train_x)

            train_y = np.array(train_y)

            ks_model.train_on_batch(train_x, train_y, sample_weight=None, class_weight=None)

有不懂或是发现我的疏漏错误的,欢迎随时联系我:[email protected]

下次见~

你可能感兴趣的:(机器学习-技术篇,自然语言处理)