flyai文本分类2


# -*- coding: utf-8 -*-
import os
import argparse
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import *
from keras.optimizers import RMSprop
from sklearn.model_selection import train_test_split, KFold
from flyai.framework import FlyAI
from flyai.data_helper import DataHelper
from path import MODEL_PATH, DATA_PATH
import pandas as pd
from keras.optimizers import Adam
import numpy as np
from data_helper import load_dict, load_labeldict, get_batches, read_data, get_val_batch
from keras_bert import load_trained_model_from_checkpoint,Tokenizer
import codecs
from keras.callbacks import *
from keras.metrics import *
from keras.utils import to_categorical
# The remote model MUST be fetched via this FlyAI helper before loading it locally.
from flyai.utils import remote_helper
path = remote_helper.get_remote_data('https://www.flyai.com/m/chinese_L-12_H-768_A-12.zip')
# Paths to the pre-trained Chinese BERT (L-12, H-768, A-12) checkpoint files.
config_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/bert_config.json')
checkpoint_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/bert_model.ckpt')
dict_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/vocab.txt')
# Top-2 accuracy metric: a prediction counts as correct when the true
# class is among the two highest-scoring predicted classes.
def acc_top2(y_true, y_pred):
    """Return the batch top-2 categorical accuracy."""
    top_k = 2
    return top_k_categorical_accuracy(y_true, y_pred, k=top_k)

'''
此项目为FlyAI2.0新版本框架,数据读取,评估方式与之前不同
2.0框架不再限制数据如何读取
样例代码仅供参考学习,可以自己修改实现逻辑。
模版项目下载支持 PyTorch、Tensorflow、Keras、MXNET、scikit-learn等机器学习框架
第一次使用请看项目中的:FlyAI2.0竞赛框架使用说明.html
使用FlyAI提供的预训练模型可查看:https://www.flyai.com/models
学习资料可查看文档中心:https://doc.flyai.com/
常见问题:https://doc.flyai.com/question.html
遇到问题不要着急,添加小姐姐微信,扫描项目里面的:FlyAI小助手二维码-小姐姐在线解答您的问题.png
'''

# Project hyper-parameters (parsed from the command line; remove if unused).
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=32, type=int, help="batch size")
args = parser.parse_args()


# Make every sequence in a batch the same length by right-padding with 0.
def seq_padding(X, padding=0):
    """Right-pad each sequence in ``X`` to the batch's maximum length.

    :param X: iterable of 1-D sequences (lists/arrays of token ids)
    :param padding: fill value appended to shorter sequences (default 0)
    :return: 2-D ``np.ndarray`` of shape ``(len(X), max_len)``; empty array
             when ``X`` is empty
    """
    # FIX: an empty batch previously raised ValueError from max() on an
    # empty sequence; return an empty array instead.
    if len(X) == 0:
        return np.array([])
    lengths = [len(x) for x in X]
    max_len = max(lengths)
    return np.array([
        np.concatenate([x, [padding] * (max_len - len(x))]) if len(x) < max_len else x
        for x in X
    ])
# Build the token -> integer-id mapping from the BERT vocabulary file,
# assigning ids in file order.
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as vocab_file:
    for vocab_line in vocab_file:
        # len(token_dict) gives the next sequential id.
        token_dict[vocab_line.strip()] = len(token_dict)

# Custom tokenizer: keep in-vocabulary characters as-is, map space-like
# characters to [unused1], and everything else to [UNK].
class OurTokenizer(Tokenizer):
    def _tokenize(self, text):
        """Tokenize ``text`` character by character."""
        tokens = []
        for ch in text:
            if ch in self._token_dict:
                tokens.append(ch)
            elif self._is_space(ch):
                # Represent whitespace with the reserved [unused1] token.
                tokens.append('[unused1]')
            else:
                # Out-of-vocabulary characters become [UNK].
                tokens.append('[UNK]')
        return tokens

tokenizer = OurTokenizer(token_dict)

# data_generator is a memory-friendly way to feed batches to Keras.
class data_generator:
    """Infinite batch generator yielding ``([token_ids, segment_ids], labels)``.

    ``data`` is a sequence of ``(text, one_hot_label)`` pairs; texts are
    truncated to 100 characters before BERT encoding.
    """

    def __init__(self, data, batch_size=32, shuffle=True):
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Number of batches per epoch, counting a final partial batch.
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:  # loop forever; fit_generator stops via steps_per_epoch
            idxs = list(range(len(self.data)))

            if self.shuffle:
                np.random.shuffle(idxs)

            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:100]  # truncate long texts
                x1, x2 = tokenizer.encode(first=text)
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                if len(X1) == self.batch_size or i == idxs[-1]:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    yield [X1, X2], Y[:, 0, :]
                    # BUG FIX: reset the accumulators after yielding.
                    # Previously they kept growing past batch_size, so only
                    # the first batch had the requested size and the rest of
                    # the epoch was emitted as one giant batch at the end.
                    X1, X2, Y = [], [], []


class Main(FlyAI):
    '''
    Projects must inherit from FlyAI, otherwise online execution will fail.
    '''

    def download_data(self):
        """Download the competition dataset via the FlyAI data helper."""
        data_helper = DataHelper()
        data_helper.download_from_ids("MedicalClass")
        print('=*=数据下载完成=*=')

    def deal_with_data(self):
        '''
        Load the training csv, build the label->id mapping, and turn each
        row into a ``(text, one_hot_label)`` pair.
        :return: None (populates ``self.data``, ``self.label2id``, ``self.train_set``)
        '''
        self.data = pd.read_csv(os.path.join(DATA_PATH, 'MedicalClass/train.csv'))
        self.label2id, _ = load_labeldict(os.path.join(DATA_PATH, 'MedicalClass/label.dict'))
        self.train_set = []
        for data_row in self.data.iloc[:].itertuples():
            self.train_set.append((data_row.text,
                                   to_categorical(self.label2id[data_row.label], len(self.label2id))))
        # Object array of (text, one_hot) pairs, indexable by KFold splits.
        self.train_set = np.array(self.train_set)
        print('=*=数据处理完成=*=')

    def train(self):
        """Fine-tune BERT with 2-fold cross validation.

        FIX: honour the parsed CLI hyper-parameters ``args.BATCH`` and
        ``args.EPOCHS``, which were previously parsed but ignored in favour
        of hard-coded values (default batch size, epochs=5).
        """
        kf = KFold(n_splits=2, shuffle=True, random_state=520).split(self.train_set)
        for i, (train_fold, test_fold) in enumerate(kf):
            X_train, X_valid = self.train_set[train_fold, :], self.train_set[test_fold, :]

            # Load the pre-trained model and fine-tune every layer.
            bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
            for l in bert_model.layers:
                l.trainable = True
            x1_in = Input(shape=(None,))
            x2_in = Input(shape=(None,))
            x = bert_model([x1_in, x2_in])
            x = Lambda(lambda x: x[:, 0])(x)  # use the [CLS] vector for classification
            p = Dense(len(self.label2id), activation='softmax')(x)
            model = Model([x1_in, x2_in], p)
            model.compile(loss='categorical_crossentropy',
                          optimizer=Adam(1e-5),  # small LR for BERT fine-tuning
                          metrics=['accuracy', acc_top2])

            early_stopping = EarlyStopping(monitor='val_acc', patience=3)  # early stop to curb overfitting
            plateau = ReduceLROnPlateau(monitor="val_acc", verbose=1, mode='max', factor=0.5,
                                        patience=2)  # halve LR when the metric plateaus
            checkpoint = ModelCheckpoint(os.path.join(MODEL_PATH, 'model.h5'), monitor='val_acc', verbose=2,
                                         save_best_only=True, mode='max')  # keep only the best model

            train_D = data_generator(X_train, batch_size=args.BATCH, shuffle=True)
            valid_D = data_generator(X_valid, batch_size=args.BATCH, shuffle=True)
            # Model training.
            model.fit_generator(
                train_D.__iter__(),
                steps_per_epoch=len(train_D),
                epochs=args.EPOCHS,
                validation_data=valid_D.__iter__(),
                validation_steps=len(valid_D),
                callbacks=[early_stopping, plateau, checkpoint],
            )


if __name__ == '__main__':
    # Standard FlyAI pipeline: download, preprocess, then train.
    runner = Main()
    runner.download_data()
    runner.deal_with_data()
    runner.train()

    exit(0)

 

你可能感兴趣的:(flyai文本分类2)