本文将介绍三个使用BERT编码句子(从BERT中提取向量)的工具。
(1)Embedding-as-service
github
这个库类似于bert-as-service,可以将句子编码成固定长度的向量,目前支持的预训练模型有BERT、ALBERT、XLNet、ELMO、Golve、word2vec等,我们可以将其作为我们模型的一部分,也可以将其作为一个服务直接使用它编码我们的句子。下面给除一个其作为服务编码句子的示例:
>>> from embedding_as_service_client import EmbeddingClient
>>> en = EmbeddingClient(host=<host_server_ip>, port=<host_port>)
>>> vecs = en.encode(texts=['hello aman', 'how are you?'])
>>> vecs
array([[[ 1.7049843 , 0. , 1.3486509 , ..., -1.3647075 ,
0.6958289 , 1.8013777 ], ... [ 0.4913215 , 0.60877025, 0.73050433, ..., -0.64490885, 0.8525057 , 0.3080206 ]]], dtype=float32)
>>> vecs.shape
(2, 128, 768) # batch x max_sequence_length x embedding_size
详细内容请点击上方给出的github链接。
(2)BERT预训练模型字向量提取工具
本工具直接读取BERT预训练模型,从中提取样本文件中所有使用到字向量,保存成向量文件,为后续模型提供embdding。
本工具直接读取预训练模型,不需要其它的依赖,同时把样本中所有 出现的字符对应的字向量全部提取,后续的模型可以非常快速进行embdding
github完整源码
#!/usr/bin/env python
# coding: utf-8
__author__ = 'xmxoxo'
'''
BERT预训练模型字向量提取工具
版本: v 0.3.2
更新: 2020/3/25 11:11
git: https://github.com/xmxoxo/BERT-Vector/
'''
import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
import os
import sys
import traceback
import pickle
gblVersion = '0.3.2'
# 如果模型的文件名不同,可修改此处
model_name = 'bert_model.ckpt'
vocab_name = 'vocab.txt'
# BERT embdding提取类
class bert_embdding():
def __init__(self, model_path='', fmt='pkl'):
# 模型和词表的文件名
ckpt_path = os.path.join(model_path, model_name)
vocab_file = os.path.join(model_path, vocab_name)
if not os.path.isfile(vocab_file):
print('词表文件不存在,请检查...')
#sys.exit()
return
# 从模型读出指定层
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
#param_dict = reader.get_variable_to_shape_map()
self.emb = reader.get_tensor("bert/embeddings/word_embeddings")
self.vocab = open(vocab_file,'r', encoding='utf-8').read().split("\n")
print('embeddings size: %s' % str(self.emb.shape))
print('词表大小:%d' % len(self.vocab))
# 兼容不同格式
self.fmt=fmt
# 取出指定字符的embdding,返回向量
def get_embdding (self, char):
if char in self.vocab:
index = self.vocab.index(char)
return self.emb[index,:]
else:
return None
# 根据字符串提取向量并保存到文件
def export (self, txt_all, out_file=''):
# 过滤重复,形成字典
txt_lst = sorted(list(set(txt_all)))
print('文本字典长度:%d, 正在提取字向量...' % len(txt_lst))
count = 0
# 可选择输出哪种格式 2020/3/25
if self.fmt=='pkl':
print('正在保存为pkl格式文件...')
# 使用字典存储,使用时更加方便。 2020/3/23
lst_vector = dict()
for word in txt_lst:
v = self.get_embdding(word)
if not (v is None):
count += 1
lst_vector[word] = v
# 改为使用pickle保存文件 2020/3/23
with open(out_file, 'wb') as out:
pickle.dump(lst_vector, out, 2)
if self.fmt=='txt':
print('正在保存为txt格式文件...')
with open(out_file, 'w', encoding='utf-8') as out:
for word in txt_lst:
v = self.get_embdding(word)
if not (v is None):
count += 1
out.write(word + " " + " ".join([str(i) for i in v])+"\n")
print('字向量共提取:%d个' % count)
# get all files and floders in a path
# fileExt: ['png','jpg','jpeg']
# return:
# return a list ,include floders and files , like [['./aa'],['./aa/abc.txt']]
@staticmethod
def getFiles (workpath, fileExt = []):
try:
lstFiles = []
lstFloders = []
if os.path.isdir(workpath):
for dirname in os.listdir(workpath) :
file_path = os.path.join(workpath, dirname)
if os.path.isfile(file_path):
if fileExt:
if dirname[dirname.rfind('.')+1:] in fileExt:
lstFiles.append (file_path)
else:
lstFiles.append (file_path)
if os.path.isdir( file_path ):
lstFloders.append (file_path)
elif os.path.isfile(workpath):
lstFiles.append(workpath)
else:
return None
lstRet = [lstFloders,lstFiles]
return lstRet
except Exception as e :
return None
# 增加批量处理目录下的某类文件 v 0.1.2 xmxoxo 2020/3/23
def export_path (self, path, ext=['csv','txt'], out_file=''):
try:
files = self.getFiles(path,ext)
# 合并数据
txt_all = []
tmp = ''
for fn in files[1]:
print('正在读取数据文件:%s' % fn)
with open(fn, 'r', encoding='utf-8') as f:
tmp = f.read()
txt_all += list(set(tmp))
txt_all = list(set(txt_all))
self.export(txt_all, out_file=out_file)
except Exception as e :
print('批量处理出错:')
print('Error in get_randstr: '+ traceback.format_exc())
return None
# 命令行
def main_cli ():
parser = argparse.ArgumentParser(description='BERT模型字向量提取工具')
parser.add_argument('-v', '--version', action='version', version='%(prog)s ' + gblVersion )
parser.add_argument('--model_path', default='', required=True, type=str, help='BERT预训练模型的目录')
parser.add_argument('--in_file', default='', required=True, type=str, help='待提取的文件名或者目录名')
parser.add_argument('--out_file', default='./bert_embedding.pkl', type=str, help='输出文件名')
parser.add_argument('--ext', default=['csv','txt'], type=str, nargs='+', help='指定目录时读取的数据文件扩展名')
parser.add_argument('--fmt', default='pkl', type=str, help='输出文件的格式,可设置txt或者pkl, 默认为pkl')
args = parser.parse_args()
# 预训练模型的目录
model_path = args.model_path
# 输出文件名
out_file = args.out_file
# 包含所有文本的内容
in_file = args.in_file
# 指定的扩展名
ext = args.ext
# 文件格式
fmt = args.fmt
if not fmt in ['pkl', 'txt']:
fmt='pkl'
if fmt=='txt' and out_file[-4:]=='.pkl':
out_file = out_file[:-3] + 'txt'
if not os.path.isdir(model_path):
print('模型目录不存在,请检查:%s' % model_path)
sys.exit()
if not (os.path.isfile(in_file) or os.path.isdir(in_file)):
print('数据文件不存在,请检查:%s' % in_file)
sys.exit()
print('\nBERT 字向量提取工具 V' + gblVersion )
print('-'*40)
bertemb = bert_embdding(model_path=model_path, fmt=fmt)
# 针对文件和目录分别处理 2020/3/23 by xmxoxo
if os.path.isfile(in_file):
txt_all = open(in_file,'r', encoding='utf-8').read()
bertemb.export(txt_all, out_file=out_file)
if os.path.isdir(in_file):
bertemb.export_path(in_file, ext=ext, out_file=out_file)
if __name__ == '__main__':
pass
main_cli()
(3)使用BERT编码句子
本文将BERT进行了封装,我们可以直接输入句子并得到句子对应的向量。
如下所示:
from bert_encoder import BertEncoder
be = BertEncoder()
embedding = be.encode("新年快乐,恭喜发财,万事如意!")
print(embedding)
print(embedding.shape)
完整封装:
完整代码
# -*- coding:utf-8 -*-
import os
from bert import modeling
import tensorflow as tf
from bert import tokenization
flags = tf.flags
FLAGS = flags.FLAGS
bert_path = r'chinese_L-12_H-768_A-12'
root_path = os.getcwd()
flags.DEFINE_string(
"bert_config_file", os.path.join(bert_path, 'bert_config.json'),
"The config json file corresponding to the pre-trained BERT model."
)
flags.DEFINE_string("vocab_file", os.path.join(bert_path, 'vocab.txt'),
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text."
)
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization."
)
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
def data_preprocess(sentence):
tokens = []
for i, word in enumerate(sentence):
# 分词,如果是中文,就是分字
token = tokenizer.tokenize(word)
tokens.extend(token)
# 序列截断
if len(tokens) >= FLAGS.max_seq_length - 1:
tokens = tokens[0:(FLAGS.max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志
ntokens = []
segment_ids = []
ntokens.append("[CLS]") # 句子开始设置CLS 标志
segment_ids.append(0)
# append("O") or append("[CLS]") not sure!
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
ntokens.append("[SEP]") # 句尾添加[SEP] 标志
segment_ids.append(0)
# append("O") or append("[SEP]") not sure!
input_ids = tokenizer.convert_tokens_to_ids(ntokens) # 将序列中的字(ntokens)转化为ID形式
# print(input_ids)
input_mask = [1] * len(input_ids)
# print(input_mask)
while len(input_ids) < FLAGS.max_seq_length:
input_ids.append(0)
input_mask.append(0)
input_ids = [input_ids]
return input_ids, input_mask
class BertEncoder(object):
def __init__(self):
self.bert_model = modeling.BertModel(config=bert_config, is_training=False, max_seq_length=FLAGS.max_seq_length)
tvars = tf.trainable_variables()
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_cheeckpoint)
tf.train.init_from_checkpoint(FLAGS.init_cheeckpoint, assignment_map)
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
def encode(self, sentence):
input_ids, input_mask = data_preprocess(sentence)
return self.sess.run(self.bert_model.embedding_output, feed_dict={
self.bert_model.input_ids:input_ids})
if __name__ == "__main__":
be = BertEncoder()
embedding = be.encode("新年快乐,恭喜发财,万事如意!")
print(embedding)
print(embedding.shape)
参考:
https://github.com/xmxoxo/BERT-Vector
https://github.com/lzphahaha/bert_encoder
https://github.com/amansrivastava17/embedding-as-service#-supported-embeddings-and-models