xlnet中文版本预训练模型终于出来了，见地址 https://github.com/ymcui/Chinese-PreTrained-XLNet 。出来之后尝试了下中文文本分类。xlnet模型相比bert有很多东西做了改变，模型层面的不多说。目前放出来的中文预训练模型采用的是24层的网络结构，是中文版bert（12层）的两倍。之前论文出来的时候相关讨论已经很多了，这里主要说中文数据处理的问题：模型采用sentencepiece做分词；pad方式是在序列前面补齐（pre-padding，即padding放在有效token之前）；模型输入是len*batch的形式；另外segment_ids和mask也和普通的模型不一样（例如mask中有效token为0、padding位置为1）。下面直接看代码吧：
数据转化为tfrecord:
import tensorflow as tf
import sys
import six
import unicodedata
import sentencepiece as spm
import collections
from textclass import FLAGS
# Segment ids written into `segment_ids`: which part of the input each
# position belongs to (first text, second text, <cls>, <sep>, padding).
SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

# Ids of XLNet's special symbols in the SentencePiece vocabulary. The keys must
# be the literal symbol strings: identical (e.g. empty) keys would collapse the
# dict into a single entry and give every *_ID below the same value.
special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
# Shared SentencePiece model used by tokenize_fn. It is loaded once, at import
# time, from FLAGS.spiece_model_file, so importing this module reads that flag.
sp = spm.SentencePieceProcessor()
sp.Load(FLAGS.spiece_model_file)
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def get_class_ids(text, max_seq_length, tokenize_fn):
    """Encode a single text into fixed-length XLNet classification features.

    Layout before padding is ``<text ids> <sep> <cls>``. The text ids are cut
    to ``max_seq_length - 2`` so that both special tokens still fit, and the
    result is then padded on the left up to ``max_seq_length``.

    Args:
        text: raw input text.
        max_seq_length: length of every returned list.
        tokenize_fn: callable mapping text to a list of token ids.

    Returns:
        ``(input_ids, input_mask, segment_ids)``. ``input_mask`` is 0 on real
        tokens and 1 on padding; padded positions get id 0 and ``SEG_ID_PAD``.
    """
    limit = max_seq_length - 2
    piece_ids = tokenize_fn(text)
    if len(piece_ids) > limit:
        piece_ids = piece_ids[:limit]

    body = list(piece_ids)
    input_ids = body + [SEP_ID, CLS_ID]
    segment_ids = [SEG_ID_A] * (len(body) + 1) + [SEG_ID_CLS]
    input_mask = [0] * len(input_ids)

    shortfall = max_seq_length - len(input_ids)
    if shortfall > 0:
        # Left padding keeps <sep> <cls> as the last two positions.
        input_ids = [0] * shortfall + input_ids
        input_mask = [1] * shortfall + input_mask
        segment_ids = [SEG_ID_PAD] * shortfall + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return input_ids, input_mask, segment_ids
def get_pair_ids(text_a, text_b, max_seq_length, tokenize_fn):
    """Encode a text pair into fixed-length XLNet classification features.

    Layout before padding is ``<a ids> <sep> <b ids> <sep> <cls>``. The two id
    lists are jointly truncated (longest first) to ``max_seq_length - 3`` so
    the three special tokens fit; the result is then left-padded.

    Args:
        text_a: first raw text.
        text_b: second raw text.
        max_seq_length: length of every returned list.
        tokenize_fn: callable mapping text to a list of token ids.

    Returns:
        ``(input_ids, input_mask, segment_ids)``. The first part (including its
        <sep>) has ``SEG_ID_A``, the second part ``SEG_ID_B``, <cls> has
        ``SEG_ID_CLS``. ``input_mask`` is 0 on real tokens and 1 on padding.
    """
    ids_a = tokenize_fn(text_a)
    ids_b = tokenize_fn(text_b)
    # Reserve room for <sep>, <sep> and <cls>.
    _truncate_seq_pair(ids_a, ids_b, max_seq_length - 3)

    first = list(ids_a) + [SEP_ID]
    second = list(ids_b) + [SEP_ID]
    input_ids = first + second + [CLS_ID]
    segment_ids = [SEG_ID_A] * len(first) + [SEG_ID_B] * len(second) + [SEG_ID_CLS]
    input_mask = [0] * len(input_ids)

    shortfall = max_seq_length - len(input_ids)
    if shortfall > 0:
        # Left padding keeps the trailing <sep> <cls> at the end.
        input_ids = [0] * shortfall + input_ids
        input_mask = [1] * shortfall + input_mask
        segment_ids = [SEG_ID_PAD] * shortfall + segment_ids

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return input_ids, input_mask, segment_ids
SPIECE_UNDERLINE = '▁'


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
    """Split ``text`` into SentencePiece pieces, separating "<digits>," endings.

    Any piece whose last character is ',' preceded by a digit is re-encoded
    without its comma (and without '▁' markers); the comma is then appended as
    its own piece. If the original piece had no leading '▁' but the re-encoding
    introduced one, that marker is removed again.

    Args:
        sp_model: a loaded SentencePiece processor.
        text: text to encode.
        return_unicode: under Python 2, decode byte-string pieces back to unicode.
        sample: use ``SampleEncodeAsPieces(text, 64, 0.1)`` instead of the
            deterministic ``EncodeAsPieces``.

    Returns:
        The list of pieces.
    """
    if six.PY2 and isinstance(text, unicode):
        text = text.encode('utf-8')

    if sample:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    else:
        pieces = sp_model.EncodeAsPieces(text)

    new_pieces = []
    for piece in pieces:
        digit_then_comma = len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit()
        if not digit_then_comma:
            new_pieces.append(piece)
            continue
        sub_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ''))
        if piece[0] != SPIECE_UNDERLINE and sub_pieces[0][0] == SPIECE_UNDERLINE:
            if len(sub_pieces[0]) == 1:
                sub_pieces = sub_pieces[1:]
            else:
                sub_pieces[0] = sub_pieces[0][1:]
        sub_pieces.append(piece[-1])
        new_pieces.extend(sub_pieces)

    # note(zhiliny): convert back to unicode for py2
    if six.PY2 and return_unicode:
        new_pieces = [p.decode('utf-8') if isinstance(p, str) else p
                      for p in new_pieces]
    return new_pieces
def encode_ids(sp_model, text, sample=False):
    """Encode ``text`` with ``encode_pieces`` and map every piece to its vocabulary id."""
    pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
    return list(map(sp_model.PieceToId, pieces))
def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
    """Normalize raw text before SentencePiece encoding.

    Steps, in order:
    1. Collapse whitespace runs to single spaces (``remove_space``).
    2. Map the quote sequences "``" and "''" to '"'.
    3. Decode byte strings to unicode (Python 2 only).
    4. Strip combining marks after NFKD normalization, unless ``keep_accents``.
    5. Lower-case the text (``lower``).
    """
    outputs = ' '.join(inputs.strip().split()) if remove_space else inputs
    for quote in ("``", "''"):
        outputs = outputs.replace(quote, '"')

    if six.PY2 and isinstance(outputs, str):
        outputs = outputs.decode('utf-8')

    if not keep_accents:
        decomposed = unicodedata.normalize('NFKD', outputs)
        outputs = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))

    return outputs.lower() if lower else outputs
def tokenize_fn(text):
    """Lower-case and normalize ``text``, then encode it to ids with the module-level ``sp`` model."""
    normalized = preprocess_text(text, lower=True)
    return encode_ids(sp, normalized)
def get_vocab(path):
    """Build a label -> id mapping from a file with one label per line.

    The id of a label is its 0-based line number in ``path``, which is read
    through ``tf.gfile`` so non-local file systems work too. Lines are stripped
    of surrounding whitespace. If a label occurs more than once, its last line
    number wins.

    Returns:
        A ``collections.defaultdict`` without a default factory, so it behaves
        like a plain dict: ``.get`` returns ``None`` for unknown labels.
    """
    maps = collections.defaultdict()
    # The `with` block closes the file; no explicit close() is needed.
    with tf.gfile.GFile(path, "r") as f:
        for i, line in enumerate(f):
            maps[line.strip()] = i
    return maps
def _int64_feature(values):
    """Wrap a list of ints as an int64 ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _open_shard(outputpath, index):
    """Open the TFRecord writer for shard ``index``; ``outputpath`` is used as a plain prefix."""
    return tf.python_io.TFRecordWriter(
        outputpath + ("xlnetreading.tfrecords-%.3d" % index))


def writedataclass(inputpath, vocab, outputpath, max_seq_length, tokenize_fn):
    """Convert a tab-separated classification file into sharded TFRecord files.

    Each non-blank input line is ``<text>\\t<label>``. The text is lower-cased,
    stripped and encoded with ``get_class_ids``. The label is mapped to an
    integer id through ``vocab``. Records are written to
    ``outputpath + "xlnetreading.tfrecords-NNN"``, and a new shard is started
    every 5000 examples.

    Args:
        inputpath: path of the tab-separated input file.
        vocab: mapping from label string to integer id (e.g. from ``get_vocab``).
        outputpath: prefix prepended verbatim to every shard name (end it with a
            path separator if it is a directory).
        max_seq_length: fixed length of every feature list.
        tokenize_fn: callable mapping text to a list of token ids.

    Raises:
        ValueError: if a non-blank line has no tab-separated label, or its label
            is not in ``vocab``.
    """
    eachonum = 5000  # examples per shard file
    num = 0
    recordfilenum = 0
    writer = _open_shard(outputpath, recordfilenum)
    try:
        with open(inputpath) as f:
            for lineno, text in enumerate(f, 1):
                if not text.strip():
                    # Blank lines (e.g. a trailing newline) carry no example.
                    continue
                texts = text.split("\t")
                if len(texts) < 2:
                    raise ValueError("%s:%d: expected '<text>\\t<label>'"
                                     % (inputpath, lineno))
                content = texts[0].lower().strip()
                label_name = texts[1].strip()
                label = vocab.get(label_name)
                if label is None:
                    raise ValueError("%s:%d: unknown label %r"
                                     % (inputpath, lineno, label_name))
                num = num + 1
                input_ids, input_mask, segment_ids = get_class_ids(
                    content, max_seq_length, tokenize_fn)
                if num > eachonum:
                    # Close (and flush) the finished shard before opening the next one.
                    writer.close()
                    writer = None
                    num = 1
                    recordfilenum = recordfilenum + 1
                    writer = _open_shard(outputpath, recordfilenum)
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={'input_ids': _int64_feature(input_ids),
                                 'input_mask': _int64_feature(input_mask),
                                 'segment_ids': _int64_feature(segment_ids),
                                 'label': _int64_feature([label]),
                                 }))
                writer.write(example.SerializeToString())
    finally:
        # Also release the current shard if conversion fails midway.
        if writer is not None:
            writer.close()
自己写了一个文本分类的类,看下:
class XlnetReadingClass(object):
    """XLNet text classifier: encoder, classification head, loss, train op and accuracy.

    Feature tensors arrive batch-major (``[batch, seq_len]``, one row per
    example as produced by ``get_class_ids``). They are transposed to the
    ``[seq_len, batch]`` layout that ``xlnet.XLNetModel`` consumes.

    Attributes built here: ``xlnet_config``, ``run_config``, the transposed
    ``input_ids`` / ``segment_ids`` / ``input_mask``, ``model``,
    ``per_example_loss``, ``logits``, ``total_loss``, ``train_op`` and ``acc``.
    """

    def __init__(self, model_config_path, is_training, FLAGS, input_ids, segment_ids,
                 input_mask, label, n_class):
        """Build the classification graph.

        Args:
            model_config_path: path of the XLNet model config JSON.
            is_training: forwarded to ``xlnet.create_run_config``; pass True when
                the graph is used for training.
            FLAGS: flag object with the XLNet run/optimizer settings plus
                ``cls_scope``, ``summary_type`` and ``use_summ_proj``.
            input_ids: ``[batch, seq_len]`` token-id tensor.
            segment_ids: ``[batch, seq_len]`` segment-id tensor.
            input_mask: ``[batch, seq_len]`` mask tensor (0 = token, 1 = padding).
            label: class-id tensor (presumably ``[batch]`` -- TODO confirm
                against the input pipeline).
            n_class: number of target classes.
        """
        self.xlnet_config = xlnet.XLNetConfig(json_path=model_config_path)
        # Second argument: build a fine-tuning run config.
        self.run_config = xlnet.create_run_config(is_training, True, FLAGS)
        # XLNetModel expects time-major [seq_len, batch] inputs.
        self.input_ids = tf.transpose(input_ids, [1, 0])
        self.segment_ids = tf.transpose(segment_ids, [1, 0])
        self.input_mask = tf.transpose(input_mask, [1, 0])
        self.model = xlnet.XLNetModel(
            xlnet_config=self.xlnet_config,
            run_config=self.run_config,
            input_ids=self.input_ids,
            seg_ids=self.segment_ids,
            input_mask=self.input_mask)
        cls_scope = FLAGS.cls_scope
        summary = self.model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
        self.per_example_loss, self.logits = modeling.classification_loss(
            hidden=summary,
            labels=label,
            n_class=n_class,
            initializer=self.model.get_initializer(),
            scope=cls_scope,
            return_logits=True)
        self.total_loss = tf.reduce_mean(self.per_example_loss)
        with tf.name_scope("train_op"):
            self.train_op, _, _ = model_utils.get_train_op(FLAGS, self.total_loss)
        with tf.name_scope("acc"):
            one_hot_target = tf.one_hot(label, n_class)
            self.acc = self.accuracy(self.logits, one_hot_target)

    def accuracy(self, logits, labels):
        """Return the fraction of examples whose predicted class matches ``labels``.

        Args:
            logits: ``[batch, n_class]`` unnormalized class scores.
            labels: one-hot targets whose last axis is the class axis; any
                extra size-1 axes are flattened away.

        Returns:
            A scalar float32 tensor.
        """
        # softmax is monotonic, so argmax over the raw logits picks the same class.
        predictions = tf.argmax(logits, axis=-1)
        # Flatten with reshape rather than tf.squeeze: an axis-less squeeze also
        # drops the batch axis when batch_size == 1, breaking argmax over axis 1.
        targets = tf.reshape(tf.argmax(labels, axis=-1), [-1])
        correct = tf.cast(tf.equal(predictions, targets), tf.float32)
        return tf.reduce_mean(correct)
def main(_):
    """Fine-tune XlnetReadingClass on the sharded TFRecord files.

    Steps:
    1. Match ``FLAGS.data_dir + "xlnetreading.tfrecords*"`` (the shard names
       written by ``writedataclass``).
    2. Build the classifier on the ``inputs`` pipeline.
    3. Optionally warm-start from ``FLAGS.init_checkpoint``.
    4. Run up to 8000 training steps. Loss and accuracy are printed every 100
       steps, and a checkpoint is saved under ``../exp/reading`` every 500 steps.
    """
    # Used by the progress log; imported here so main does not depend on
    # module-level imports that are not part of this file's import block.
    import datetime
    import time
    from datetime import timedelta

    print('Loading config...')
    n_class = 38
    input_path = FLAGS.data_dir + "xlnetreading.tfrecords*"
    print("input_path:", input_path)
    files = tf.train.match_filenames_once(input_path)
    # `inputs` is the user's own TFRecord reading pipeline (not shown here); it
    # returns batched feature tensors read from the matched files.
    input_ids, input_mask, segment_ids, label_ids = inputs(
        files, batch_size=FLAGS.batch_size, num_epochs=5,
        max_seq_length=FLAGS.max_seq_length)
    model_config_path = FLAGS.model_config_path
    # This function trains the model (it runs model.train_op below), so build it
    # in training mode; XLNet uses run_config.is_training to switch dropout on.
    is_training = True
    init_checkpoint = FLAGS.init_checkpoint
    model = XlnetReadingClass(model_config_path, is_training, FLAGS, input_ids,
                              segment_ids, input_mask, label_ids, n_class)
    tvars = tf.trainable_variables()
    # Default so the variable listing below also works without a checkpoint.
    initialized_variable_names = {}
    if init_checkpoint:
        (assignment_map, initialized_variable_names) = \
            model_utils.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        print("restore sucess on cpu or gpu")
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    # match_filenames_once (and presumably the num_epochs counter inside
    # `inputs`) are local variables, so they need their own initializer.
    session.run(tf.local_variables_initializer())
    print("**** Trainable Variables ****")
    for var in tvars:
        init_string = ", *INIT_FROM_CKPT*" if var.name in initialized_variable_names else ""
        print("name ={0}, shape = {1}{2}".format(var.name, var.shape, init_string))
    print("xlnet reading class model will start train .........")
    print(session.run(files))
    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=session)
    start_time = time.time()
    try:
        for i in range(8000):
            _, loss_train, acc = session.run(
                [model.train_op, model.total_loss, model.acc])
            if i % 100 == 0:
                time_dif = timedelta(seconds=int(round(time.time() - start_time)))
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2},' \
                      + ' Cost: {2} Time:{3} acc:{4}'
                print(msg.format(i, loss_train, time_dif,
                                 datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                                 acc))
                start_time = time.time()
            if i % 500 == 0 and i > 0:
                saver.save(session, "../exp/reading/model.ckpt", global_step=i)
    except tf.errors.OutOfRangeError:
        # Raised when the input queues run dry (presumably after `num_epochs`
        # passes over the data inside `inputs`).
        print("input data exhausted, training stopped early")
    finally:
        # Always stop the queue-runner threads and release the session.
        coord.request_stop()
        coord.join(threads)
        session.close()