In my previous article, 大白话讲懂word2vec原理和如何使用, I introduced the idea of turning words into something a computer can actually compute with, i.e. converting text into numbers or vectors, explained in detail how the word2vec model works, and showed how to use word2vec under the gensim framework to turn a word into a vector that expresses its features. However, before you can extract feature vectors with word2vec in gensim, you need to prepare a corpus for your scenario, segment it into words, feed it into the model for training, and only then can the trained model be used for feature extraction. The whole workflow is fairly tedious, and the quality of the model depends on how well the corpus covers the domain and on the choice of training hyperparameters. In fact, NLP already offers ready-made models for text feature extraction, such as Google's BERT, which became very popular a few years ago. This article gives a brief introduction to BERT and shows how a newcomer can quickly get started using it to extract text features for downstream NLP tasks.
BERT was created in 2018 by Jacob Devlin and his colleagues at Google, who described it in the paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding". BERT is an acronym for "Bidirectional Encoder Representations from Transformers".
The original English BERT release provided two kinds of pre-trained models: (1) the BERT-BASE model, a network with 12 layers, a hidden size of 768, 12 self-attention heads, and 110M parameters; (2) the BERT-LARGE model, a network with 24 layers, a hidden size of 1024, 16 self-attention heads, and 340M parameters. Both were trained on BooksCorpus and English Wikipedia, which contain roughly 0.8 billion and 2.5 billion words respectively.
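To make these numbers concrete, here is a minimal sketch, assuming the Hugging Face transformers library introduced below, that builds a randomly initialized model with the BERT-BASE hyperparameters and counts its parameters:

from transformers import BertConfig, BertModel

# BERT-BASE hyperparameters described above; everything else is left at its default
config = BertConfig(
    num_hidden_layers=12,     # 12 Transformer encoder layers
    hidden_size=768,          # 768-dimensional hidden states
    num_attention_heads=12,   # 12 self-attention heads
)
model = BertModel(config)     # randomly initialized, not pre-trained
print(sum(p.numel() for p in model.parameters()))  # roughly 110M parameters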
Of course, companies in China such as Tencent and Baidu have also open-sourced BERT models trained on Chinese corpora. This article uses the pre-trained BERT MLM model fine-tuned for 3 epochs on Chinese text, provided by the open-source library pycorrector; the download link can be found in the README inside the bert folder of that repository. Why use this particular pre-trained BERT model? Because it was built on top of transformers, is fully compatible with it, and is very convenient to use. Of course, BERT also has open-source pre-trained models for various other frameworks, so readers are free to pick whichever they prefer.
Here is also a brief introduction to the transformers framework. Transformers is an open-source machine learning framework from Hugging Face. It provides thousands of pre-trained models supporting text classification, information extraction, question answering, summarization, translation, and text generation in more than 100 languages. Its goal is to make state-of-the-art NLP easy for everyone to use.
Transformers provides APIs for quickly downloading pre-trained models and applying them to a given text, fine-tuning them on your own dataset, and then sharing them with the community through the model hub. At the same time, every Python module it defines is fully self-contained, which makes it easy to modify for quick research experiments.
Transformers supports, and integrates seamlessly with, the three most popular deep learning libraries: Jax, PyTorch and TensorFlow. You can train a model with one framework and then load it with another for inference.
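To give a feel for how little code a typical task takes, here is a minimal sketch of the pipeline API; the "sentiment-analysis" task downloads a default English model on first use, and the exact model and scores depend on the installed version:

from transformers import pipeline

# Build a ready-to-use text classification pipeline from a default pre-trained model
classifier = pipeline("sentiment-analysis")
# Prints a list containing a predicted label and a confidence score
print(classifier("Transformers makes state-of-the-art NLP easy to use."))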
First, install the required dependencies:
pip install transformers
pip install numpy
pip install torch
After installation, place the downloaded Chinese BERT model under the home directory (on Windows this is the "C:\Users\<username>\" directory). The last step is simply to extract the Chinese word vectors:
import os
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

# Work around the duplicate-OpenMP-runtime error that torch can trigger on some systems
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Path to the fine-tuned Chinese BERT model placed under the home directory
model_dir = os.path.expanduser('~/.pycorrector/datasets/bert_models/chinese_finetuned_lm')
# print(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
print("token ok")
model = AutoModel.from_pretrained(model_dir)
print("model ok")

# The tokenizer turns the sentence into vocabulary ids with [CLS]/[SEP] added:
# tensor([[ 101, 3217, 4697, 679, 6230, 3236, 102]])
inputs = tokenizer('春眠不觉晓', return_tensors='pt')
outputs = model(**inputs)  # last_hidden_state shape (1, 7, 768)
print(outputs)
v = torch.mean(outputs[0], dim=1)  # average over tokens -> sentence vector, shape (1, 768)
print(v)
print(v.shape)
Here is the printed output:
token ok
model ok
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.0462, 0.5329, 0.1521, ..., 0.1440, -0.4105, 0.2640],
[ 0.3368, 0.5288, -0.5288, ..., 0.0241, -0.0017, 0.6836],
[ 0.8783, -0.7624, -0.3651, ..., -0.1129, -0.1835, 0.4492],
...,
[ 1.1596, 0.7187, -0.4550, ..., 0.4255, -1.0546, 0.4000],
[ 0.3089, 0.1456, -0.7718, ..., -0.3547, -1.1788, -0.1022],
[-0.1397, 0.1891, 0.0370, ..., -0.4032, 0.1067, 0.8418]]],
grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[ 0.9999, 1.0000, 1.0000, 0.9994, 0.9917, 0.9673, -0.9951, -0.9998,
0.7602, -0.9998, 1.0000, 0.8789, -0.9962, -0.9991, 0.9999, -0.9998,
-0.9808, 0.9962, 0.9995, -0.8180, 1.0000, -1.0000, -0.7954, 0.9890,
-0.0880, 0.9994, 0.9973, -0.9975, -1.0000, 0.9999, 0.9937, 0.9998,
0.7745, -1.0000, -1.0000, 0.9997, 0.7679, 0.9995, 0.9810, -0.8222,
-0.9881, -0.9682, 0.4597, -1.0000, -0.9997, 0.9258, -1.0000, -1.0000,
-0.6581, 1.0000, -0.9976, -1.0000, -0.8982, 0.9489, -0.9080, 0.9993,
-1.0000, 0.9999, 1.0000, 0.9088, 0.9998, -1.0000, 0.9745, -0.9999,
1.0000, -0.9999, -0.9999, 0.9022, 1.0000, 1.0000, 0.9859, 0.9994,
1.0000, 0.9899, 0.9669, 0.9610, -0.9999, 0.9868, -1.0000, 0.3895,
1.0000, 0.9917, -0.9991, 0.9879, -0.9999, -1.0000, -0.9998, 1.0000,
-0.9313, 1.0000, 1.0000, -0.9998, -1.0000, 0.9967, -0.9996, -0.9954,
-0.9987, 0.9998, 0.3839, -0.9951, 0.5363, -0.0769, -0.6780, -0.9602,
0.9937, 0.9996, 0.9477, -0.9996, 0.9999, -0.9088, -1.0000, -0.9984,
-0.9999, -0.9949, -0.9994, 1.0000, -0.5663, -0.9630, 0.9999, -0.9999,
0.9980, -1.0000, -0.9767, -0.9776, 0.9995, 1.0000, 0.9999, -0.9992,
0.7933, 1.0000, 0.9947, 0.9999, -1.0000, 0.9998, 0.9886, -0.9998,
-0.6499, -0.9700, 1.0000, 0.9997, 0.9998, -0.9993, 0.9999, -1.0000,
1.0000, -1.0000, 1.0000, -1.0000, -0.9996, 1.0000, 0.5835, 1.0000,
0.4881, 1.0000, -0.9987, -1.0000, 0.2071, 0.9841, 0.9772, -1.0000,
0.9966, -0.9978, -0.0712, -0.8875, -1.0000, 1.0000, -0.9943, 1.0000,
0.9936, -0.9993, -0.9932, -0.9999, -0.7205, -0.9999, -0.9955, 0.9945,
-0.6264, 0.9999, -0.9988, -0.9974, 0.9961, -0.9993, -1.0000, 0.9997,
-0.9633, 0.9966, 0.9968, -0.8558, 0.1031, -0.9476, -0.9900, 1.0000,
0.9906, 0.9335, 0.9998, 0.5868, -0.9853, -0.9261, -1.0000, -0.9779,
1.0000, -0.9802, -1.0000, 0.9919, -1.0000, 0.5570, -0.9243, -0.9926,
-0.9998, -0.9999, 0.9982, -0.3304, -0.9999, 0.9972, 0.7547, -0.7263,
-1.0000, 0.9684, 0.9999, 0.9794, 0.9992, -0.9990, -1.0000, 0.9999,
-0.9994, -0.9656, 0.9782, 1.0000, 0.9999, 0.8883, 0.9906, 1.0000,
0.8663, -1.0000, 0.9902, -1.0000, -0.9686, 1.0000, -0.9999, 0.9714,
1.0000, 0.9804, 1.0000, -0.9932, -1.0000, -0.9999, 1.0000, 0.9992,
1.0000, -0.9994, -1.0000, -0.7841, -0.9982, -1.0000, -1.0000, 0.5399,
0.9999, 1.0000, 0.9529, -0.9999, -0.9983, -0.9997, 1.0000, -0.9995,
1.0000, 0.9985, -0.9990, -0.9984, 0.4102, -0.9887, -0.9999, 0.9476,
-1.0000, -0.9998, -1.0000, 0.5216, -0.9999, -1.0000, 0.9927, 1.0000,
0.9832, -1.0000, 1.0000, 1.0000, 0.6493, -0.9897, 0.9987, -1.0000,
1.0000, -0.9998, 0.9992, -0.9910, -0.9988, 0.6841, 1.0000, 0.9991,
-0.0945, -0.8882, -0.9999, -0.9984, 0.6373, 0.9962, -0.9928, 0.9837,
-0.9835, -0.5475, -0.6533, -0.9646, -1.0000, 0.9946, 1.0000, -0.9937,
1.0000, 0.9995, 1.0000, 0.9988, -0.9979, 0.9999, -0.9568, -0.9371,
-0.9975, -0.9796, 0.9989, 0.9934, -0.9996, -1.0000, 1.0000, -0.4073,
0.0726, 0.9991, -0.9819, 0.7489, 0.9996, -0.9997, 0.9989, -0.9999,
-0.9999, 0.9999, 1.0000, 0.9999, 0.7050, -0.9942, 0.9992, -1.0000,
0.9998, -1.0000, 0.9999, 0.7262, 0.9978, -0.9997, -0.9997, 1.0000,
0.9191, -0.7063, 0.9999, -0.9997, 0.9823, 0.9794, 0.9979, 0.9994,
0.9998, 1.0000, -0.4151, -0.4769, -0.9945, -0.9979, -0.9999, -1.0000,
0.9442, -1.0000, -0.9992, -0.5763, -0.9380, 0.9991, -0.9614, 0.9741,
-0.9200, 0.7585, -0.6832, 0.7789, 0.9100, -0.9990, -0.9997, -1.0000,
-0.9996, 0.9937, 1.0000, -1.0000, 1.0000, -1.0000, -0.9945, 0.9909,
-0.9975, -0.9766, 0.9982, -1.0000, 0.9968, 1.0000, 1.0000, 0.9997,
1.0000, -0.0377, -0.9999, -0.9999, -1.0000, -1.0000, -1.0000, 0.9967,
0.5768, -1.0000, -0.9999, 0.9974, 1.0000, 0.9999, -1.0000, -0.9960,
-1.0000, -1.0000, 0.9998, -1.0000, -1.0000, 0.8403, 0.2925, 1.0000,
-0.7300, 0.9968, 0.9721, -0.8463, 0.9994, -1.0000, 0.9960, 1.0000,
0.9902, -1.0000, -0.8275, -0.3285, -1.0000, -0.5722, 0.9921, 1.0000,
-1.0000, -0.9919, -0.9993, 0.9939, 0.9998, 1.0000, 0.9999, 0.9973,
0.9966, 0.9990, -0.2278, 1.0000, 0.9019, -0.9999, 1.0000, -0.9931,
0.3810, -0.9970, 1.0000, 0.9909, 1.0000, 0.9994, 0.0773, -0.8879,
-1.0000, 0.9805, 1.0000, -0.9960, -0.9988, -1.0000, -0.9999, -0.9999,
-0.9837, 0.9481, -0.9999, -0.9999, -0.2732, 0.8276, 1.0000, 1.0000,
0.9999, -0.9998, -0.9134, 0.9975, -0.9959, 0.9502, -0.9299, -1.0000,
-0.9998, -0.9908, 1.0000, -0.9984, 0.8839, -0.9043, 0.9909, 0.9980,
-1.0000, -0.9808, -0.9997, 0.9978, 1.0000, -1.0000, 0.9995, -0.9992,
0.9993, 0.9830, 0.9992, 0.9999, -0.9392, -0.6955, -0.8634, -0.6895,
0.9410, 0.9981, -1.0000, -0.6234, 1.0000, -0.9660, 0.9999, 0.5241,
0.2568, 0.9977, 1.0000, 0.9999, 1.0000, 0.9980, 0.9982, 1.0000,
0.5588, 0.9976, -0.4912, -0.9997, 0.9614, -0.1142, 1.0000, -0.9842,
-0.9951, -1.0000, 0.8644, 1.0000, 1.0000, -0.9995, 0.9998, 0.8025,
0.6844, 0.9762, 0.9916, 0.9893, 0.3525, 0.9999, 1.0000, -1.0000,
-0.9999, -1.0000, 1.0000, 0.9998, -0.9740, -1.0000, 0.9999, -0.9981,
0.3949, 0.9960, 0.9281, -0.9850, 0.8734, -1.0000, -0.1166, 0.9119,
0.9983, -0.9729, 0.9999, -0.9999, 0.9485, 1.0000, -0.9886, 1.0000,
0.3598, -1.0000, 1.0000, -1.0000, -0.9999, -0.1344, 1.0000, 0.9999,
0.9721, -0.9948, 1.0000, -1.0000, 1.0000, -1.0000, -0.9784, -0.9999,
1.0000, -0.9999, -0.9995, -0.9961, 0.9934, -0.6841, -0.9959, 1.0000,
0.9830, -0.0785, 0.0016, -0.9982, -0.9997, -0.9993, -0.9789, -1.0000,
0.9880, 0.6295, -0.9854, -0.9989, -1.0000, 1.0000, 0.9102, -0.9993,
1.0000, 0.6416, -1.0000, 0.9998, -1.0000, 0.9974, 0.9995, 0.9282,
0.9281, -1.0000, 0.9451, 1.0000, -0.9977, -0.1010, -0.8494, -0.9989,
0.9680, 0.9994, 0.9994, -0.9982, 0.9991, 0.9782, 0.9994, -0.9993,
0.5694, -1.0000, -0.9947, -0.9771, -0.5015, -1.0000, -1.0000, 1.0000,
1.0000, 1.0000, -0.9997, -0.9630, 0.9945, 0.9997, -0.9997, -0.5486,
-0.4082, 0.9995, 0.9815, -0.9995, -0.9222, -1.0000, -0.9999, 0.8829,
0.9606, 0.9846, 1.0000, 1.0000, -0.9993, -0.9850, -1.0000, -1.0000,
1.0000, 0.9998, 1.0000, -0.9935, -0.9935, 1.0000, -0.9598, -0.9793,
-0.9996, -1.0000, -1.0000, 0.9801, -0.9996, -1.0000, 0.9997, 1.0000,
0.2486, -1.0000, -0.9970, 1.0000, 1.0000, 1.0000, 0.9682, 1.0000,
-0.9981, 0.9991, -0.9999, 1.0000, -1.0000, 1.0000, 1.0000, 0.9988,
0.9996, -0.9998, 0.9529, -1.0000, 0.9445, 0.9934, -0.9129, -0.9996,
0.7888, -0.7471, -0.9998, 1.0000, 0.9782, -0.9456, 0.9992, 0.8537,
1.0000, -0.7332, -1.0000, 0.8291, 0.9964, 0.9996, 1.0000, 0.9961,
1.0000, -0.9619, -0.9998, 0.9999, -0.9962, -0.9640, -1.0000, 1.0000,
0.9997, -1.0000, -0.9990, -0.0881, 0.9771, 1.0000, 0.9998, 0.9856,
0.9443, 0.7739, 0.9998, -1.0000, 0.9999, -0.9997, -0.9957, 1.0000,
-0.9999, 0.9999, -0.9997, 0.9967, -1.0000, 0.3715, 0.9992, 0.9983,
-0.9992, 1.0000, 0.9850, -0.9977, -0.9998, -0.9996, -0.9993, 0.9992]],
grad_fn=<TanhBackward>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)
tensor([[ 4.2115e-01, 2.3145e-01, -2.0698e-01, -9.7974e-02, 4.9489e-01,
2.7838e-01, -1.1931e-01, 1.2249e-01, 5.7828e-01, -1.4129e-02,
-2.6775e-01, 9.0389e-02, 1.0490e-01, 2.3848e-01, 2.4841e-01,
-5.6864e-01, -1.8439e-01, -1.8489e-01, -5.0514e-01, 1.3315e-01,
4.9129e-01, 1.8245e-01, -2.6851e-01, -1.0386e+00, -7.9079e-01,
1.4436e-01, 4.9574e-01, -9.7084e-02, 8.0020e-01, -4.6640e-01,
-2.6626e-01, -2.8174e-01, -3.6215e-01, 5.1038e-01, 6.0100e-01,
-4.5917e-01, 1.0059e-01, -3.4047e-01, 1.1231e-01, 1.1119e-01,
2.3894e-01, 1.1199e-01, 1.9337e-01, -6.2538e-01, -7.4717e-02,
-1.2579e-01, 6.5073e-01, 6.4615e-01, -5.8761e-02, 2.2246e-01,
-2.5051e-01, 8.6206e+00, 1.4232e-02, 2.9667e-01, -1.3845e+00,
1.1278e-01, 5.9608e-01, 5.9820e-02, 2.9962e-01, -2.5438e-01,
2.8565e-02, -6.6922e-01, 3.6433e-01, -2.8556e-01, -7.1027e-01,
1.4349e-01, -3.9762e-01, 2.2639e-01, -2.6528e-01, -9.4360e-02,
4.1309e-01, 3.2021e-01, -9.2492e-04, 1.7673e-01, -3.3852e-01,
8.5758e-01, 3.5135e-02, -5.2421e-01, 4.9388e-01, -1.5129e-01,
-1.0801e+00, -1.6625e-01, -5.9788e-01, -8.1008e-02, -4.5741e-01,
-8.6737e-02, -1.0806e+00, -2.2317e+00, 8.1670e-01, 2.4053e-01,
3.0502e-03, -3.6126e-02, -2.3908e-01, 5.5755e-02, 5.4440e-01,
3.0456e-01, 3.7178e-01, 2.6597e-02, -1.2701e-01, 7.3959e-01,
-1.6025e-01, 1.2154e+00, -6.3310e-01, 1.2381e+00, -2.8937e-01,
3.7735e-01, 4.1021e-01, -3.5903e-02, -8.0253e-01, 1.5286e-02,
-7.4596e-01, -1.5062e-01, -2.4371e-01, 2.9766e-01, -2.5416e-02,
2.2347e-01, -1.2874e-01, 1.6149e-01, -7.6720e-01, -2.2865e-01,
8.2740e-02, 5.8240e-01, 5.9125e-01, -9.1472e-01, 1.8060e-02,
-8.7869e-01, -2.0138e-01, 2.2509e-01, 2.0105e-01, 5.2098e-01,
1.6654e-01, -7.5825e-03, -7.6159e-01, 3.6378e-01, 8.5763e-01,
-4.7668e-01, 7.9492e-01, -1.3293e-01, -9.8156e-01, -5.2239e-02,
-1.1053e-01, -2.3038e-01, -1.8885e-01, 5.8654e-02, -7.5245e-02,
5.7493e-01, -5.8965e-01, -8.3425e-01, -9.1647e-02, -3.9919e-01,
-5.9551e-01, -3.8222e-01, -3.9354e-01, -1.2872e+00, 1.7037e-01,
6.2567e-02, 3.7266e-01, -4.3533e-01, -1.5423e-01, 8.6063e-02,
1.4821e-02, -5.9938e-01, 1.2751e-01, 3.0867e-01, 2.7782e-01,
-1.0011e+00, 8.3422e-01, -1.4378e-01, -4.3895e-01, -1.1647e-01,
3.4525e-01, -3.5966e-01, 3.6655e-01, -5.9473e-01, -3.9539e-01,
2.7654e-01, 5.3645e-02, -8.5819e-01, -2.1018e-01, -1.4444e+00,
1.9370e-01, -7.4787e-01, 6.9408e-01, -2.7827e-01, 1.2964e-01,
-8.9586e-02, -2.4729e-01, 3.4385e-01, -7.2836e-01, 1.0293e-01,
6.1057e-01, 1.2272e-01, -1.1748e+00, 2.0129e-01, 5.6979e-01,
2.6673e-02, -5.0832e-01, 3.2063e-01, 8.4628e-01, 8.5049e-01,
-1.0433e-02, -1.0727e+00, -9.1567e-01, 2.1140e-01, 3.1516e-01,
6.1549e-01, 4.2328e-01, -1.0585e-01, -4.5701e-01, 9.7370e-01,
5.0359e-01, -4.9752e-02, 3.3154e-01, 1.8237e-01, 1.4244e-01,
1.8660e-01, 2.9726e-01, -5.3041e-01, 1.0201e-02, -2.2264e-01,
1.1000e-01, 1.0814e-01, 6.9984e-01, 2.5458e-01, -3.6922e-02,
1.8352e-02, 6.2996e-01, -1.2014e-02, -5.9554e-01, 6.6884e-02,
-4.9590e-01, -1.5742e-01, -6.2079e-02, -3.6228e-01, -4.4048e-01,
4.2865e-01, 4.4279e-02, -6.4402e-01, -7.8738e-01, -4.3362e-02,
2.1974e-01, -3.6023e-01, -1.1008e+00, 5.3952e-02, -3.6299e-01,
2.6865e-01, 1.3385e-01, 1.9683e-01, -1.7107e-01, 2.3446e-01,
-2.2911e-01, -5.3821e-01, -1.9868e-01, -1.3766e+00, -2.4666e-01,
6.4290e-01, -6.9668e-01, -7.4800e-01, -5.5722e-01, -4.4907e-01,
-1.3621e-01, 4.2140e-01, 9.5959e-01, 4.4360e-01, 3.7554e-01,
-1.8354e-01, 9.7835e-02, -5.5753e-01, -3.2415e-01, -8.0041e-02,
4.4006e-01, 2.6940e-01, -3.0654e-01, -2.3053e-01, -5.2161e-01,
-6.4866e-02, 4.9422e-01, 7.5057e-01, 1.8852e-01, -3.3040e-01,
1.2468e+00, 2.9247e-02, -6.6896e-01, -9.5221e-02, -5.9996e-01,
-1.9553e-01, 3.8359e-02, 3.2242e-01, -1.8793e-01, -5.3541e-02,
-3.1233e-01, 3.8802e-01, 1.3228e-01, 1.0668e+00, 1.8578e+00,
-8.0071e-01, 1.2525e-01, -4.9549e-01, 2.0179e-01, -8.4097e-01,
1.8610e-01, -4.7953e-02, -3.0395e-01, 8.5403e-01, -5.3822e-02,
-5.0801e-01, 8.2333e-01, 5.0291e-01, -9.3993e-01, -4.6249e-01,
-4.1557e-01, -2.2973e-01, -2.2915e-01, -6.1189e-01, -2.7792e-01,
3.9598e-01, -3.4302e-01, 6.4712e-01, 2.2012e-01, 4.4055e-01,
-4.4096e-01, 6.9279e-02, -2.8966e-01, 1.0388e+00, -1.1961e-01,
-8.2876e-02, -1.4306e-01, 2.9727e-01, -1.2053e-01, 7.9857e-01,
9.0269e-03, -1.5311e-01, -5.4600e-01, 8.1083e-01, 1.1571e+00,
-3.9596e-01, 6.7261e-01, 4.7498e-01, -7.7686e-01, -2.9285e-01,
-1.5991e-01, 1.4495e-01, 2.2916e-01, 5.6729e-02, -1.9356e-01,
4.4161e-01, 4.0350e-01, 6.1182e-01, 4.2114e-01, -3.9481e-01,
-6.3641e-01, 5.6452e-02, -8.1091e-01, -1.0413e-03, -4.5391e-01,
1.7060e+00, -4.5662e-01, 5.8219e-02, 1.4496e+00, -1.3005e-01,
1.2138e-01, 1.8672e-01, -4.0238e-01, -2.9981e-01, 2.5241e-01,
-2.9240e-01, -4.5109e-02, -6.0901e-01, 7.5708e-01, -5.5898e-01,
6.4741e-01, 5.7061e-01, -5.5011e-01, -7.6833e-01, -1.2637e-01,
-4.7630e-01, 3.6280e-01, -1.6268e-01, 4.3015e-01, 3.1131e-02,
-1.7380e-01, -2.1078e-01, -6.3284e-02, -7.5465e-01, -1.6997e-01,
4.2480e-03, 4.7350e-01, -9.3531e-01, -1.1029e-01, -1.9922e-01,
4.4874e-01, -1.3954e-01, -6.0780e-01, -9.7392e-02, 5.7029e-01,
1.0055e+00, 2.3332e-01, -8.8975e-02, -6.3464e-02, 1.2084e+00,
7.7065e-01, -1.1713e-01, 1.7554e-01, -7.2155e-01, 7.5223e-02,
3.6250e-01, 4.5526e-01, 6.7044e-01, 2.8244e-01, -1.2149e+00,
-1.1918e+00, 2.2679e-01, 9.8096e-01, 2.1820e-01, 1.4098e-01,
-1.7236e-01, 2.2309e-01, 1.5257e-01, 6.6268e-01, -1.0398e-01,
-3.0996e-01, 7.0454e-01, -2.6190e-01, -1.0244e-01, 6.4382e-01,
-4.8694e-01, -5.3009e-01, 2.8935e-01, -5.8808e-01, -1.4827e-01,
-8.6694e-01, -7.6133e-02, 2.6520e-01, -6.9277e-01, 7.4502e-01,
8.5377e-01, -3.6214e-01, 4.8850e-01, -3.8880e-01, 5.1802e-01,
-3.7141e-01, -7.7974e-01, 6.9056e-01, -4.9195e-01, 1.2370e-01,
5.5867e-01, 6.1279e-01, -3.0551e-01, 6.3116e-01, -2.2843e-01,
7.1648e-02, -4.6328e-01, -5.1349e-01, -3.3412e-01, 1.8214e-01,
-6.5556e-01, 4.2179e-01, 6.4876e-01, 6.9453e-02, 1.9235e-01,
-3.9381e-01, 1.8176e-01, 2.6222e-01, 4.9721e-01, 4.9734e-01,
3.7234e-02, -4.3901e-01, -1.4314e-01, 7.3539e-01, -3.5594e-02,
1.0551e-01, 8.7454e-02, 1.1364e-01, 6.1150e-01, -4.7138e-01,
1.0484e-01, -1.5497e-01, -1.4640e-02, -1.2692e+00, 8.9680e-03,
1.6948e-01, -8.3116e-01, 1.3112e+00, -8.9783e-01, 7.2376e-01,
3.6621e-01, 1.6011e-01, 1.2368e-02, -6.5463e-01, -4.1487e-02,
-9.6702e-01, -2.8059e-01, 2.6916e-01, 6.7283e-01, -5.2141e-01,
-2.5983e-01, 1.8001e-01, 1.0120e-01, -2.0587e-01, -4.9314e-01,
-6.8490e-01, -1.5255e+00, 1.8295e-01, -2.5930e-01, -1.2546e+00,
-6.9589e-02, 4.3109e-01, -2.0965e-03, -3.3771e-01, 3.5798e-01,
9.4842e-01, -4.0370e-01, -5.9109e-01, -3.0457e-02, 8.6064e-01,
2.0668e-01, -3.9533e-02, -1.1155e+00, 1.6056e-01, -6.5845e-01,
8.1298e-01, -6.7716e-01, 7.3191e-01, -6.5311e-03, -8.2889e-01,
2.2935e-01, -1.0802e+00, -2.7448e-01, 1.6502e-01, 6.5893e-01,
-2.7223e-01, -7.5991e-02, 2.9284e-01, -5.6872e-01, 6.5480e-01,
7.2336e-02, 1.8865e-01, -1.0025e+00, -3.8879e-01, -6.7788e-01,
-1.2397e-01, -3.5319e-01, -2.9062e-01, -2.1113e+00, -2.8191e-01,
-7.8004e-01, 1.7187e-01, 7.2854e-01, -1.2840e-01, 1.7805e-01,
-9.9736e-01, 6.6500e-01, -1.1308e+00, 8.9092e-02, -7.6484e-02,
-3.3015e-01, -8.3021e-03, 2.0718e-01, 1.5108e-01, -4.2013e-01,
2.6676e-01, -7.4054e-02, 3.5822e-01, -5.4754e-01, -6.3731e-01,
1.6186e-01, -3.3157e-01, 4.8395e-01, -3.2022e-01, -7.7350e-01,
7.0298e-01, 1.9313e-01, -1.4782e+00, 1.6191e-01, 7.8461e-01,
7.4405e-01, -4.7152e-01, 9.6684e-02, 6.4955e-01, -3.2820e-01,
7.1990e-01, 1.3264e+00, -4.5643e-01, 3.7712e-01, 3.0739e-01,
4.9210e-01, -8.8734e-01, 7.0449e-01, 4.2371e-01, -6.9552e-01,
-4.3504e-01, 7.1409e-02, -3.4701e-01, 1.1066e-01, -6.0244e-01,
4.9573e-02, -1.2310e-01, -2.2023e-01, -8.0928e-01, 1.0534e-01,
2.8099e-01, -4.2681e-01, 4.7401e-02, 1.5718e-01, 3.3319e-01,
-6.7182e-01, 7.1283e-01, 1.1922e+00, -9.8484e-02, 1.4709e-02,
7.8795e-01, 3.9003e-01, 7.0788e-02, -5.5820e-01, -1.6251e-01,
4.7181e-01, -4.2967e-01, 2.4140e-01, -1.6781e-01, -7.4418e-01,
4.0156e-01, 1.0813e+00, 5.5052e-01, -4.2373e-01, -6.1423e-01,
-1.7941e-01, 2.7244e-02, 4.9143e-01, 2.2138e-01, -7.6815e-01,
-8.6145e-01, 4.8813e-01, -8.6220e-01, 4.5893e-01, -1.9820e-02,
1.0818e-01, -4.0966e-01, 7.4174e-01, -3.6610e-01, 2.1310e-02,
8.0221e-01, -1.8562e-01, 5.8339e-02, 4.0976e-01, 7.0982e-01,
-3.1056e-01, 1.5611e-01, 8.4351e-02, 1.8569e-01, -1.6442e-02,
6.1764e-01, -5.8445e-02, -3.6498e-01, 4.0751e-01, -3.0788e-01,
-9.5406e-02, 7.1717e-01, 2.4616e-01, -3.7505e-01, -2.1185e-01,
3.5740e-01, -4.4704e-01, -1.9303e-01, -2.2978e-02, -1.7094e-01,
-5.1454e-01, -7.2254e-01, -7.0447e-01, 2.7950e-01, 6.4647e-01,
7.6261e-01, -3.2582e-01, 1.0665e+00, -8.2867e-01, 3.0364e-01,
-3.7458e-02, -7.8399e-01, -2.0923e-01, -2.5211e-01, -2.0208e-01,
2.0298e-01, -6.1091e-01, -1.8543e-01, -5.2403e-01, 1.6136e-05,
5.4596e-01, -9.8686e-01, -3.3052e-01, 3.0897e-01, 4.7233e-01,
3.8098e-02, 9.4033e-01, 3.2290e-01, 8.6700e-03, -2.4415e-01,
3.9516e+00, -4.7525e-01, 7.1882e-01, 2.5092e-01, -1.0444e-01,
1.9712e-01, 7.1822e-01, -1.3286e-01, -3.0513e-01, -1.0075e-01,
6.1208e-02, -3.2429e-01, -4.3805e-01, 2.2502e-01, 3.1383e-03,
-8.6180e-01, -5.3833e-01, 2.1707e-01, -8.0440e-02, -1.5623e-01,
1.3789e+00, 2.2651e-01, -6.3904e-01, -5.1710e-01, -6.7277e-01,
-6.8123e-01, -4.4332e-01, 1.1069e+00, 3.0457e-01, -4.2791e-01,
-2.5976e-01, -4.9939e-01, 8.8211e-02, 5.7401e-01, -1.2304e-01,
-5.9881e-01, 4.9584e-01, 1.1418e-01, 3.9800e-01, 3.0608e-02,
8.9072e-01, 2.0687e-01, 2.7497e-01, 1.6121e-01, 1.1269e+00,
-3.1077e-01, -1.0664e+00, -1.5314e-01, -4.9027e-01, 1.9091e-01,
3.8845e-01, 1.7292e-01, -6.2660e-01, -1.0976e-01, 5.3393e-01,
-6.6391e-01, 5.6167e-01, 8.1336e-01, -3.8923e-04, 1.5196e-01,
-3.2085e-01, 2.7172e+00, -5.1197e-01, 9.2552e-01, -4.9456e-01,
-3.3745e-01, 2.8565e-01, 2.3598e-01, -1.4363e-01, 1.5846e-01,
-8.3367e-02, -4.9820e-01, 5.0628e-01]], grad_fn=<MeanBackward1>)
torch.Size([1, 768])
Note that this Chinese feature extraction with BERT splits a line of Chinese text directly into individual characters and encodes each character as its id in the tokenizer's vocabulary. For example, with the input "春眠不觉晓", the tokenizer turns it into tensor([[ 101, 3217, 4697, 679, 6230, 3236, 102]]), where 101 and 102 correspond to the special tokens "[CLS]" and "[SEP]", and the five ids in the middle are the encodings of the five input characters.
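You can verify this mapping yourself; the small check below reuses the tokenizer and inputs from the code above and converts the ids back to tokens:

ids = inputs['input_ids'][0].tolist()
# Expected: ['[CLS]', '春', '眠', '不', '觉', '晓', '[SEP]']
print(tokenizer.convert_ids_to_tokens(ids))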
After the whole sentence passes through BERT feature extraction, the output is a tensor of shape 1×7×768: each of the 7 tokens gets a 768-dimensional feature vector. The final line v = torch.mean(outputs[0], dim=1)  # shape (1, 768) averages the features of all tokens in the sentence, yielding a single feature vector for the whole sentence.
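As a hypothetical example of a downstream use of these sentence vectors, the sketch below reuses the tokenizer and model loaded above, wraps the same mean-pooling step in a helper function (the name sentence_vector is just for illustration, not part of any library), and compares two sentences with cosine similarity:

import torch

def sentence_vector(text):
    # Tokenize, run BERT, and average the token features into one 768-dimensional vector
    inputs = tokenizer(text, return_tensors='pt')
    with torch.no_grad():  # no gradients needed for plain feature extraction
        outputs = model(**inputs)
    return torch.mean(outputs[0], dim=1)  # shape (1, 768)

v1 = sentence_vector('春眠不觉晓')
v2 = sentence_vector('处处闻啼鸟')
print(torch.cosine_similarity(v1, v2))  # cosine similarity of the two sentence vectors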