BERT Event Entity Extraction, with Predict

Prediction code has been added. Accuracy can still be improved, and the problem that the extracted span sometimes cannot be mapped back to the original sentence, caused by BERT's tokenization, is not solved yet.

The maximum sequence length is 512 (510 characters plus [CLS] and [SEP]); the data has already been filtered accordingly.
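The unresolved alignment issue mentioned above comes from the fact that FullTokenizer lowercases the text, splits words into '##' pieces, and maps out-of-vocabulary characters to [UNK], so a predicted token span cannot always be cut directly out of the raw sentence. Below is a minimal sketch of one possible workaround (greedy character re-matching); this helper is only an illustration and is not part of the script that follows.

def rematch_tokens_to_text(text, tokens):
    # greedily align WordPiece tokens to character offsets in the original text;
    # returns a list of (char_start, char_end) pairs, one per token
    normalized = text.lower()
    offsets, cursor = [], 0
    for token in tokens:
        if token in ('[CLS]', '[SEP]'):
            offsets.append((cursor, cursor))        # special tokens cover nothing
            continue
        if token == '[UNK]':
            offsets.append((cursor, cursor + 1))    # assumption: one character per [UNK]
            cursor += 1
            continue
        piece = token[2:] if token.startswith('##') else token
        found = normalized.find(piece, cursor)
        if found == -1:
            offsets.append((cursor, cursor))        # give up on this token
            continue
        offsets.append((found, found + len(piece)))
        cursor = found + len(piece)
    return offsets

# usage: cut the predicted span out of the raw text instead of joining tokens
# offsets = rematch_tokens_to_text(raw_text, tokens)
# answer = raw_text[offsets[start][0]:offsets[end][1]]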

import tensorflow as tf
import numpy as np
from bert import modeling
from bert import tokenization
from bert import optimization
import os


flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('train_batch_size',6,'define the train batch size')
flags.DEFINE_integer('num_train_epochs',3,'define the num train epochs')
flags.DEFINE_float('warmup_proportion',0.1,'define the warmup proportion')
flags.DEFINE_float('learning_rate',5e-5,'the initial learning rate for adam')
flags.DEFINE_bool('is_training',True,'define whether to fine-tune the bert model')
flags.DEFINE_integer('max_sentence_len',512,'define the max len of sentence')
flags.DEFINE_bool('task_train',True,'define the train task')
flags.DEFINE_bool('task_predict',True,'define the predict task')


def get_start_end_index(text,subtext):
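    # return token-level (start, end) indices of subtext inside text, or (-1, -1) if not found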
    for i in range(len(text)):
        if text[i:i+len(subtext)] == subtext:
            return (i,i+len(subtext)-1)
    return (-1,-1)


# each line of the data files is a Python literal (parsed with eval); judging from how
# the items are indexed below, a train item is (text, category, entity) and a test item
# is (text, category)
train_data = []
with open('data/train_data.txt', encoding='UTF-8') as fp:
    strLines = fp.readlines()
    strLines = [item.strip() for item in strLines]
    strLines = [eval(item) for item in strLines]
    train_data.extend(strLines)

test_data = []
with open('data/test_data.txt',encoding='UTF-8') as fp:
    strLines = fp.readlines()
    strLines = [item.strip() for item in strLines]
    strLines = [eval(item) for item in strLines]
    test_data.extend(strLines)



# config_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_config.json'
# checkpoint_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_model.ckpt'
# dict_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\vocab.txt'
config_path = './bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './bert/chinese_L-12_H-768_A-12/vocab.txt'
bert_config = modeling.BertConfig.from_json_file(config_path)
tokenizer = tokenization.FullTokenizer(vocab_file=dict_path,do_lower_case=True)




def input_str_concat(inputList):
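    # build one BERT input: concatenate category and text as '__category__text',
    # truncate to 510 characters, then wrap with [CLS]/[SEP] so the sequence
    # stays within the 512-token limit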
    assert len(inputList) == 2
    t, c = inputList
    newStr = '__%s__%s' % (c, t)
    newStr = newStr[:510]
    tokens = tokenizer.tokenize(newStr)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)
    return tokens, (input_ids, input_mask, segment_ids)



def sequence_padding(sequence):
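    # zero-pad every sequence in the batch to the length of the longest one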
    lenlist = [len(item) for item in sequence]
    maxlen = max(lenlist)
    return np.array([
        np.concatenate([item,[0]*(maxlen - len(item))]) if len(item) < maxlen else item for item in sequence
    ])


# get train data batch
def get_data_batch():
    batch_size = FLAGS.train_batch_size
    epoch = FLAGS.num_train_epochs
    for oneEpoch in range(epoch):
        num_batches = ((len(train_data) -1) // batch_size) + 1
        for i in range(num_batches):
            batch_data = train_data[i*batch_size:(i+1)*batch_size]
            yield_batch_data = {
                'input_ids':[],
                'input_mask':[],
                'segment_ids':[],
                'start_ids':[],
                'end_ids':[]
            }
            for item in batch_data:
                tokens, (input_ids, input_mask, segment_ids) = input_str_concat(item[:-1])
                target_tokens = tokenizer.tokenize(item[2])

                start, end = get_start_end_index(tokens, target_tokens)
                if start == -1:
                    # the tokenized target could not be located in the tokenized input
                    # (e.g. because of truncation or [UNK]); skip this sample
                    continue

                start_ids = [0] * len(input_ids)
                end_ids = [0] * len(input_ids)
                start_ids[start] = 1
                end_ids[end] = 1
                yield_batch_data['input_ids'].append(input_ids)
                yield_batch_data['input_mask'].append(input_mask)
                yield_batch_data['segment_ids'].append(segment_ids)
                yield_batch_data['start_ids'].append(start_ids)
                yield_batch_data['end_ids'].append(end_ids)
            if not yield_batch_data['input_ids']:
                continue  # every sample in this batch was skipped
            yield_batch_data['input_ids'] = sequence_padding(yield_batch_data['input_ids'])
            yield_batch_data['input_mask'] = sequence_padding(yield_batch_data['input_mask'])
            yield_batch_data['segment_ids'] = sequence_padding(yield_batch_data['segment_ids'])
            yield_batch_data['start_ids'] = sequence_padding(yield_batch_data['start_ids'])
            yield_batch_data['end_ids'] = sequence_padding(yield_batch_data['end_ids'])
            yield yield_batch_data


with tf.Graph().as_default(),tf.Session() as sess:
    input_ids_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='input_ids_p')
    input_mask_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='input_mask_p')
    segment_ids_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='segment_ids_p')
    start_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='start_p')
    end_p = tf.placeholder(dtype=tf.int64,shape=[None,None],name='end_p')


    model = modeling.BertModel(config=bert_config,
                               is_training=FLAGS.is_training,
                               input_ids=input_ids_p,
                               input_mask=input_mask_p,
                               token_type_ids=segment_ids_p,
                               use_one_hot_embeddings=False)
    output_layer = model.get_sequence_output()

    # hidden size of the BERT sequence output ([batch_size, seq_len, hidden_size])
    word_dim = output_layer.get_shape().as_list()[-1]

    with tf.variable_scope('weight_and_bias', reuse=tf.AUTO_REUSE, initializer=tf.truncated_normal_initializer(mean=0., stddev=0.05)):
        weight_start = tf.get_variable(name='weight_start',shape=[word_dim,1])
        bias_start = tf.get_variable(name='bias_start',shape=[1])
        weight_end = tf.get_variable(name='weight_end',shape=[word_dim,1])
        bias_end = tf.get_variable(name='bias_end',shape=[1])

    with tf.name_scope('predict_start_and_end'):
        # project each token's hidden vector to a scalar start/end logit: [batch, seq_len]
        pred_start = tf.einsum('ijk,kd->ijd', output_layer, weight_start)
        pred_start = tf.nn.bias_add(pred_start, bias_start)
        pred_start = tf.squeeze(pred_start, -1)

        pred_end = tf.einsum('ijk,kd->ijd', output_layer, weight_end)
        pred_end = tf.nn.bias_add(pred_end, bias_end)
        pred_end = tf.squeeze(pred_end, -1)

        # predicted start/end token positions (argmax over the sequence dimension)
        pred_start_index = tf.argmax(pred_start, axis=1)
        pred_end_index = tf.argmax(pred_end, axis=1)

    with tf.name_scope('loss'):
        start_labels = tf.cast(start_p, tf.float32)
        end_labels = tf.cast(end_p, tf.float32)

        loss1 = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred_start, labels=start_labels))

        # mask the end logits so that positions before the true start position are
        # pushed towards -inf: the end of a span cannot precede its start
        mask_before_start = 1.0 - tf.cumsum(start_labels, axis=1)
        pred_end_masked = pred_end - mask_before_start * 100000.0

        loss2 = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred_end_masked, labels=end_labels))
        loss = loss1 + loss2

    with tf.name_scope('acc_predict'):
        start_acc_bool = tf.equal(tf.argmax(start_p, axis=1), tf.argmax(pred_start, axis=1))
        end_acc_bool = tf.equal(tf.argmax(end_p, axis=1), tf.argmax(pred_end, axis=1))
        start_acc = tf.reduce_mean(tf.cast(start_acc_bool, dtype=tf.float32))
        end_acc = tf.reduce_mean(tf.cast(end_acc_bool, dtype=tf.float32))
        total_acc = tf.reduce_mean(tf.cast(tf.logical_and(start_acc_bool, end_acc_bool), dtype=tf.float32))

    with tf.name_scope('train_op'):
        num_train_steps = int(
            len(train_data) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
        train_op = optimization.create_optimizer(
                loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)



    # map trainable variables to the pretrained BERT checkpoint and initialize everything
    tvars = tf.trainable_variables()
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)
    tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
    sess.run(tf.variables_initializer(tf.global_variables()))

    if FLAGS.task_train:
        total_steps = 0
        for yield_batch_data in get_data_batch():
            total_steps += 1
            feed_dict = {
                input_ids_p: yield_batch_data['input_ids'],
                input_mask_p: yield_batch_data['input_mask'],
                segment_ids_p: yield_batch_data['segment_ids'],
                start_p: yield_batch_data['start_ids'],
                end_p: yield_batch_data['end_ids']
            }
            fetches = [train_op, loss, start_acc, end_acc, total_acc]

            _, loss_val, start_acc_val, end_acc_val, total_acc_val = sess.run(fetches, feed_dict=feed_dict)
            print('i : %s, loss : %s, start_acc : %s, end_acc : %s, total_acc : %s' % (
                total_steps, loss_val, start_acc_val, end_acc_val, total_acc_val))
        print('train task done ...')

    if FLAGS.task_predict:
        resultList = []
        for item in test_data:
            if item[1] == '其他':
                # the '其他' ("other") category has no target entity to extract
                resultList.append('NaN')
            else:
                yield_batch_data = {
                    'input_ids': [],
                    'input_mask': [],
                    'segment_ids': [],
                }

                tokens, (input_ids, input_mask, segment_ids) = input_str_concat(item)

                yield_batch_data['input_ids'].append(input_ids)
                yield_batch_data['input_mask'].append(input_mask)
                yield_batch_data['segment_ids'].append(segment_ids)

                yield_batch_data['input_ids'] = sequence_padding(yield_batch_data['input_ids'])
                yield_batch_data['input_mask'] = sequence_padding(yield_batch_data['input_mask'])
                yield_batch_data['segment_ids'] = sequence_padding(yield_batch_data['segment_ids'])

                feed_dict = {
                    input_ids_p: yield_batch_data['input_ids'],
                    input_mask_p: yield_batch_data['input_mask'],
                    segment_ids_p: yield_batch_data['segment_ids']
                }

                fetches = [pred_start_index, pred_end_index]

                start_index, end_index = sess.run(fetches, feed_dict=feed_dict)
                start = start_index[0]
                end = end_index[0]
                # join the predicted token span back into a string
                oneResult = ''.join(tokens[start:end + 1])

                if oneResult in item[0]:
                    resultList.append(oneResult)
                elif oneResult.upper() in item[0]:
                    resultList.append(oneResult.upper())
                else:
                    # tokenization artifacts ([UNK], lower-casing, WordPiece '##') may
                    # prevent an exact match against the original sentence; fall back
                    # to the cleaned, upper-cased span
                    oneResult = oneResult.upper().replace('[UNK]', '').replace('#', '').strip()
                    resultList.append(oneResult)

        with open('result.txt',encoding='UTF-8',mode='a') as ff:
            ff.write('\n'.join(resultList)+'\n')





 
