# -*- coding: utf-8 -*-
import os
import argparse
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import *
from keras.optimizers import RMSprop
from sklearn.model_selection import train_test_split, KFold
from flyai.framework import FlyAI
from flyai.data_helper import DataHelper
from path import MODEL_PATH, DATA_PATH
import pandas as pd
from keras.optimizers import Adam
import numpy as np
from data_helper import load_dict, load_labeldict, get_batches, read_data, get_val_batch
from keras_bert import load_trained_model_from_checkpoint,Tokenizer
import codecs
from keras.callbacks import *
from keras.metrics import *
from keras.utils import to_categorical
# The pre-trained model MUST be downloaded via this FlyAI helper before it can be loaded.
from flyai.utils import remote_helper
path = remote_helper.get_remote_data('https://www.flyai.com/m/chinese_L-12_H-768_A-12.zip')
# Paths to the downloaded pre-trained Chinese BERT: config, checkpoint weights, and vocabulary.
config_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/bert_config.json')
checkpoint_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/bert_model.ckpt')
dict_path = os.path.join(DATA_PATH, 'model/chinese_L-12_H-768_A-12/vocab.txt')
# Top-2 accuracy: a prediction counts as correct when the true class is
# among the two highest-scoring predicted classes.
def acc_top2(y_true, y_pred):
    """Return the top-2 categorical accuracy for a batch of predictions."""
    return top_k_categorical_accuracy(y_true=y_true, y_pred=y_pred, k=2)
'''
此项目为FlyAI2.0新版本框架,数据读取,评估方式与之前不同
2.0框架不再限制数据如何读取
样例代码仅供参考学习,可以自己修改实现逻辑。
模版项目下载支持 PyTorch、Tensorflow、Keras、MXNET、scikit-learn等机器学习框架
第一次使用请看项目中的:FlyAI2.0竞赛框架使用说明.html
使用FlyAI提供的预训练模型可查看:https://www.flyai.com/models
学习资料可查看文档中心:https://doc.flyai.com/
常见问题:https://doc.flyai.com/question.html
遇到问题不要着急,添加小姐姐微信,扫描项目里面的:FlyAI小助手二维码-小姐姐在线解答您的问题.png
'''
# Command-line hyper-parameters for the project (may be removed if unused).
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=32, type=int, help="batch size")
args = parser.parse_args()
# Make every sequence in a batch the same length by right-padding with 0.
def seq_padding(X, padding=0):
    """Right-pad the sequences in ``X`` with ``padding`` to a common length.

    :param X: sequence of 1-D sequences (lists or arrays) of varying length
    :param padding: fill value appended to shorter sequences (default 0)
    :return: 2-D numpy array of shape (len(X), max_length)
    """
    if len(X) == 0:
        # Guard: the original called max() on an empty list, raising ValueError.
        return np.array(X)
    max_len = max(len(x) for x in X)
    return np.array([
        np.concatenate([x, [padding] * (max_len - len(x))]) if len(x) < max_len else x
        for x in X
    ])
# Build the token -> integer-id mapping from the BERT vocab file (one token per line).
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        # Ids are assigned in file order, matching the pre-trained checkpoint's vocab.
        token_dict[token] = len(token_dict)
# Custom tokenizer: strictly character-level, and never silently drops characters.
class OurTokenizer(Tokenizer):
    """Character-level tokenizer for BERT.

    Each character maps to itself when present in the vocabulary, to
    '[unused1]' when it is a whitespace character, and to '[UNK]' otherwise.
    """

    def _tokenize(self, text):
        def _to_token(ch):
            if ch in self._token_dict:
                return ch
            # '[unused1]' stands in for whitespace so spacing survives encoding.
            return '[unused1]' if self._is_space(ch) else '[UNK]'

        return [_to_token(ch) for ch in text]
# Shared tokenizer instance built from the BERT vocabulary; used by data_generator.
tokenizer = OurTokenizer(token_dict)
# data_generator is a memory-friendly way to feed batches to fit_generator.
class data_generator:
    """Endless batch iterator over (text, one-hot-label) pairs.

    Yields ``([token_ids, segment_ids], labels)`` batches suitable for
    Keras ``fit_generator``.
    """

    def __init__(self, data, batch_size=32, shuffle=True):
        # data: indexable sequence of (text, one_hot_label) pairs
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Number of batches per epoch; a trailing partial batch counts as one step.
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))
            if self.shuffle:
                np.random.shuffle(idxs)
            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:100]  # truncate each text to 100 chars before encoding
                x1, x2 = tokenizer.encode(first=text)
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                if len(X1) == self.batch_size or i == idxs[-1]:
                    X1 = seq_padding(X1)
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    # Y has shape (batch, 1, num_classes); drop the middle axis.
                    yield [X1, X2], Y[:, 0, :]
                    # BUG FIX: reset the buffers after yielding. The originals
                    # were rebound to ndarrays by seq_padding, so when the
                    # generator resumed, the next .append() raised
                    # AttributeError (ndarray has no append).
                    X1, X2, Y = [], [], []
class Main(FlyAI):
    '''
    FlyAI competition entry point. The project must subclass FlyAI or it
    fails when run on the platform.
    '''

    def download_data(self):
        """Download the MedicalClass dataset via the FlyAI data helper."""
        data_helper = DataHelper()
        data_helper.download_from_ids("MedicalClass")
        print('=*=数据下载完成=*=')

    def deal_with_data(self):
        '''
        Load the training CSV and turn each row into a (text, one-hot label)
        pair stored in self.train_set.
        :return: None
        '''
        self.data = pd.read_csv(os.path.join(DATA_PATH, 'MedicalClass/train.csv'))
        # label name -> integer id mapping loaded from the provided dictionary file
        self.label2id, _ = load_labeldict(os.path.join(DATA_PATH, 'MedicalClass/label.dict'))
        self.train_set = []
        for data_row in self.data.iloc[:].itertuples():
            self.train_set.append((data_row.text,
                                   to_categorical(self.label2id[data_row.label], len(self.label2id))))
        # Object ndarray so KFold fold indices can fancy-index it.
        self.train_set = np.array(self.train_set)
        print('=*=数据处理完成=*=')

    def train(self):
        """Fine-tune BERT with 2-fold CV; best model saved to MODEL_PATH/model.h5."""
        kf = KFold(n_splits=2, shuffle=True, random_state=520).split(self.train_set)
        for i, (train_fold, test_fold) in enumerate(kf):
            X_train, X_valid = self.train_set[train_fold, :], self.train_set[test_fold, :]
            # Load pre-trained BERT; seq_len=None allows variable-length inputs.
            bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)
            for l in bert_model.layers:
                l.trainable = True  # fine-tune every BERT layer
            x1_in = Input(shape=(None,))  # token ids
            x2_in = Input(shape=(None,))  # segment ids
            x = bert_model([x1_in, x2_in])
            x = Lambda(lambda t: t[:, 0])(x)  # take the [CLS] vector for classification
            p = Dense(len(self.label2id), activation='softmax')(x)
            model = Model([x1_in, x2_in], p)
            model.compile(loss='categorical_crossentropy',
                          optimizer=Adam(1e-5),  # small LR is essential for fine-tuning
                          metrics=['accuracy', acc_top2])
            # NOTE(review): Keras >= 2.3 logs 'val_accuracy' instead of
            # 'val_acc' for the 'accuracy' metric; if these callbacks warn
            # about an unknown monitor, switch the name — confirm the
            # installed Keras version.
            early_stopping = EarlyStopping(monitor='val_acc', patience=3)  # early stop to limit overfitting
            plateau = ReduceLROnPlateau(monitor="val_acc", verbose=1, mode='max', factor=0.5,
                                        patience=2)  # halve LR when validation accuracy plateaus
            checkpoint = ModelCheckpoint(os.path.join(MODEL_PATH, 'model.h5'), monitor='val_acc', verbose=2,
                                         save_best_only=True, mode='max')  # keep the best model only
            # FIX: honour the command-line hyper-parameters. The original
            # parsed -b/-e but hard-coded batch_size=32 and epochs=5.
            train_D = data_generator(X_train, batch_size=args.BATCH, shuffle=True)
            valid_D = data_generator(X_valid, batch_size=args.BATCH, shuffle=True)
            model.fit_generator(
                train_D.__iter__(),
                steps_per_epoch=len(train_D),
                epochs=args.EPOCHS,
                validation_data=valid_D.__iter__(),
                validation_steps=len(valid_D),
                callbacks=[early_stopping, plateau, checkpoint],
            )
if __name__ == '__main__':
    # Standard FlyAI pipeline: download data, preprocess, then train.
    main = Main()
    main.download_data()
    main.deal_with_data()
    main.train()
    exit(0)