使用transformers框架导入bert模型提取中文词向量

导言

在笔者的上一篇文章大白话讲懂word2vec原理和如何使用中提到了如何将词语转变成计算机能够识别的语言,即将文本数据转换成计算机能够运算的数字或者向量这个概念,并详细阐述了word2vec这个模型的原理,如何在gensim框架下使用word2vec将单词转变成一个能够表达单词特征的向量。但是在gensim框架下使用word2vec模型提取特征向量之前,需要准备一些场景中的语料,同事也需要对语料进行分词操作,然后再输入到模型中进行训练,最后才可以用训练好的模型进行特征提取。整个流程还是较为繁琐,同时模型效果取决于语料覆盖度是否全面和训练超参数的选择。其实呢,nlp也有一些现成的文本特征提取的模型,比如前些年比较火爆的google推出的bert模型。本篇文章就简单的介绍一下bert模型,以及新手如何快速上手使用bert提取文本特征,以便去做后续的nlp任务。

bert模型和Transformers框架

bert是由Google的雅各布·德夫林和同事在2018年创建的,并发表了论文Pre-training of Deep Bidirectional Transformers for Language Understanding,bert是”Bidirectional Encoder Representations from Transformers”的首字母缩, 直译就是基于变换器的双向编码器表示技术。

最初的英语bert发布时提供两种类型的预训练模型:(1)BERTBASE模型,一个12层,768维,12个自注意头(self attention head),110M参数的神经网络结构;(2)BERTLARGE模型,一个24层,1024维,16个自注意头,340M参数的神经网络结构。两者的训练语料都是BooksCorpus以及英语维基百科语料,单词量分别是8亿以及25亿。

当然,国内也有公司,比如像腾讯,百度等,也都开源了基于中文语料训练了的bert中文模型。这篇文章使用的是大佬开源库pycorrector提供的用中文文本fine-tuned3轮后的预训练BERT MLM模型,下载链接在该开源库的bert文件夹中的readme中。这里为什么要使用这个bert预训练模型呢,因为这个模型是基于transformers开发的,和transformers完美兼容,使用起来也很方面。当然了,bert还有其他各种框架的开源预训练模型,读者们可以自行挑选自己喜欢的。

这里也简单介绍一下transformers框架。Transformers是huggingface开源的一个机器学习框架。Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨让最先进的 NLP 技术人人易用。

Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 model hub 与社区共享。同时,每个定义的 Python 模块均完全独立,方便修改和快速研究实验。

Transformers 支持三个最热门的深度学习库: Jax, PyTorch and TensorFlow — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。

bert模型提取中文词向量

首先需要安装所需要的依赖库

pip install transformers
pip install numpy
pip install torch

安装完成之后,需要将下载的bert中文模型放在home路径下,如果是windows系统就是"C:\Users\用户名/"目录。最后就是直接提取中文词向量

import os

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
model_dir = os.path.expanduser('~/.pycorrector/datasets/bert_models/chinese_finetuned_lm') 
#print(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
print("token ok")
model = AutoModel.from_pretrained(model_dir)
print("model ok")
# tensor([[ 101, 3217, 4697,  679, 6230, 3236,  102]])
inputs = tokenizer('春眠不觉晓', return_tensors='pt')
outputs = model(**inputs)  # shape (1, 7, 768)
print(outputs)
v = torch.mean(outputs[0], dim=1)  # shape (1, 768)
print(v)
print(v.shape)

这是打印结果

token ok
model ok
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.0462,  0.5329,  0.1521,  ...,  0.1440, -0.4105,  0.2640],
         [ 0.3368,  0.5288, -0.5288,  ...,  0.0241, -0.0017,  0.6836],
         [ 0.8783, -0.7624, -0.3651,  ..., -0.1129, -0.1835,  0.4492],
         ...,
         [ 1.1596,  0.7187, -0.4550,  ...,  0.4255, -1.0546,  0.4000],
         [ 0.3089,  0.1456, -0.7718,  ..., -0.3547, -1.1788, -0.1022],
         [-0.1397,  0.1891,  0.0370,  ..., -0.4032,  0.1067,  0.8418]]],
       grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[ 0.9999,  1.0000,  1.0000,  0.9994,  0.9917,  0.9673, -0.9951, -0.9998,
          0.7602, -0.9998,  1.0000,  0.8789, -0.9962, -0.9991,  0.9999, -0.9998,
         -0.9808,  0.9962,  0.9995, -0.8180,  1.0000, -1.0000, -0.7954,  0.9890,
         -0.0880,  0.9994,  0.9973, -0.9975, -1.0000,  0.9999,  0.9937,  0.9998,
          0.7745, -1.0000, -1.0000,  0.9997,  0.7679,  0.9995,  0.9810, -0.8222,
         -0.9881, -0.9682,  0.4597, -1.0000, -0.9997,  0.9258, -1.0000, -1.0000,
         -0.6581,  1.0000, -0.9976, -1.0000, -0.8982,  0.9489, -0.9080,  0.9993,
         -1.0000,  0.9999,  1.0000,  0.9088,  0.9998, -1.0000,  0.9745, -0.9999,
          1.0000, -0.9999, -0.9999,  0.9022,  1.0000,  1.0000,  0.9859,  0.9994,
          1.0000,  0.9899,  0.9669,  0.9610, -0.9999,  0.9868, -1.0000,  0.3895,
          1.0000,  0.9917, -0.9991,  0.9879, -0.9999, -1.0000, -0.9998,  1.0000,
         -0.9313,  1.0000,  1.0000, -0.9998, -1.0000,  0.9967, -0.9996, -0.9954,
         -0.9987,  0.9998,  0.3839, -0.9951,  0.5363, -0.0769, -0.6780, -0.9602,
          0.9937,  0.9996,  0.9477, -0.9996,  0.9999, -0.9088, -1.0000, -0.9984,
         -0.9999, -0.9949, -0.9994,  1.0000, -0.5663, -0.9630,  0.9999, -0.9999,
          0.9980, -1.0000, -0.9767, -0.9776,  0.9995,  1.0000,  0.9999, -0.9992,
          0.7933,  1.0000,  0.9947,  0.9999, -1.0000,  0.9998,  0.9886, -0.9998,
         -0.6499, -0.9700,  1.0000,  0.9997,  0.9998, -0.9993,  0.9999, -1.0000,
          1.0000, -1.0000,  1.0000, -1.0000, -0.9996,  1.0000,  0.5835,  1.0000,
          0.4881,  1.0000, -0.9987, -1.0000,  0.2071,  0.9841,  0.9772, -1.0000,
          0.9966, -0.9978, -0.0712, -0.8875, -1.0000,  1.0000, -0.9943,  1.0000,
          0.9936, -0.9993, -0.9932, -0.9999, -0.7205, -0.9999, -0.9955,  0.9945,
         -0.6264,  0.9999, -0.9988, -0.9974,  0.9961, -0.9993, -1.0000,  0.9997,
         -0.9633,  0.9966,  0.9968, -0.8558,  0.1031, -0.9476, -0.9900,  1.0000,
          0.9906,  0.9335,  0.9998,  0.5868, -0.9853, -0.9261, -1.0000, -0.9779,
          1.0000, -0.9802, -1.0000,  0.9919, -1.0000,  0.5570, -0.9243, -0.9926,
         -0.9998, -0.9999,  0.9982, -0.3304, -0.9999,  0.9972,  0.7547, -0.7263,
         -1.0000,  0.9684,  0.9999,  0.9794,  0.9992, -0.9990, -1.0000,  0.9999,
         -0.9994, -0.9656,  0.9782,  1.0000,  0.9999,  0.8883,  0.9906,  1.0000,
          0.8663, -1.0000,  0.9902, -1.0000, -0.9686,  1.0000, -0.9999,  0.9714,
          1.0000,  0.9804,  1.0000, -0.9932, -1.0000, -0.9999,  1.0000,  0.9992,
          1.0000, -0.9994, -1.0000, -0.7841, -0.9982, -1.0000, -1.0000,  0.5399,
          0.9999,  1.0000,  0.9529, -0.9999, -0.9983, -0.9997,  1.0000, -0.9995,
          1.0000,  0.9985, -0.9990, -0.9984,  0.4102, -0.9887, -0.9999,  0.9476,
         -1.0000, -0.9998, -1.0000,  0.5216, -0.9999, -1.0000,  0.9927,  1.0000,
          0.9832, -1.0000,  1.0000,  1.0000,  0.6493, -0.9897,  0.9987, -1.0000,
          1.0000, -0.9998,  0.9992, -0.9910, -0.9988,  0.6841,  1.0000,  0.9991,
         -0.0945, -0.8882, -0.9999, -0.9984,  0.6373,  0.9962, -0.9928,  0.9837,
         -0.9835, -0.5475, -0.6533, -0.9646, -1.0000,  0.9946,  1.0000, -0.9937,
          1.0000,  0.9995,  1.0000,  0.9988, -0.9979,  0.9999, -0.9568, -0.9371,
         -0.9975, -0.9796,  0.9989,  0.9934, -0.9996, -1.0000,  1.0000, -0.4073,
          0.0726,  0.9991, -0.9819,  0.7489,  0.9996, -0.9997,  0.9989, -0.9999,
         -0.9999,  0.9999,  1.0000,  0.9999,  0.7050, -0.9942,  0.9992, -1.0000,
          0.9998, -1.0000,  0.9999,  0.7262,  0.9978, -0.9997, -0.9997,  1.0000,
          0.9191, -0.7063,  0.9999, -0.9997,  0.9823,  0.9794,  0.9979,  0.9994,
          0.9998,  1.0000, -0.4151, -0.4769, -0.9945, -0.9979, -0.9999, -1.0000,
          0.9442, -1.0000, -0.9992, -0.5763, -0.9380,  0.9991, -0.9614,  0.9741,
         -0.9200,  0.7585, -0.6832,  0.7789,  0.9100, -0.9990, -0.9997, -1.0000,
         -0.9996,  0.9937,  1.0000, -1.0000,  1.0000, -1.0000, -0.9945,  0.9909,
         -0.9975, -0.9766,  0.9982, -1.0000,  0.9968,  1.0000,  1.0000,  0.9997,
          1.0000, -0.0377, -0.9999, -0.9999, -1.0000, -1.0000, -1.0000,  0.9967,
          0.5768, -1.0000, -0.9999,  0.9974,  1.0000,  0.9999, -1.0000, -0.9960,
         -1.0000, -1.0000,  0.9998, -1.0000, -1.0000,  0.8403,  0.2925,  1.0000,
         -0.7300,  0.9968,  0.9721, -0.8463,  0.9994, -1.0000,  0.9960,  1.0000,
          0.9902, -1.0000, -0.8275, -0.3285, -1.0000, -0.5722,  0.9921,  1.0000,
         -1.0000, -0.9919, -0.9993,  0.9939,  0.9998,  1.0000,  0.9999,  0.9973,
          0.9966,  0.9990, -0.2278,  1.0000,  0.9019, -0.9999,  1.0000, -0.9931,
          0.3810, -0.9970,  1.0000,  0.9909,  1.0000,  0.9994,  0.0773, -0.8879,
         -1.0000,  0.9805,  1.0000, -0.9960, -0.9988, -1.0000, -0.9999, -0.9999,
         -0.9837,  0.9481, -0.9999, -0.9999, -0.2732,  0.8276,  1.0000,  1.0000,
          0.9999, -0.9998, -0.9134,  0.9975, -0.9959,  0.9502, -0.9299, -1.0000,
         -0.9998, -0.9908,  1.0000, -0.9984,  0.8839, -0.9043,  0.9909,  0.9980,
         -1.0000, -0.9808, -0.9997,  0.9978,  1.0000, -1.0000,  0.9995, -0.9992,
          0.9993,  0.9830,  0.9992,  0.9999, -0.9392, -0.6955, -0.8634, -0.6895,
          0.9410,  0.9981, -1.0000, -0.6234,  1.0000, -0.9660,  0.9999,  0.5241,
          0.2568,  0.9977,  1.0000,  0.9999,  1.0000,  0.9980,  0.9982,  1.0000,
          0.5588,  0.9976, -0.4912, -0.9997,  0.9614, -0.1142,  1.0000, -0.9842,
         -0.9951, -1.0000,  0.8644,  1.0000,  1.0000, -0.9995,  0.9998,  0.8025,
          0.6844,  0.9762,  0.9916,  0.9893,  0.3525,  0.9999,  1.0000, -1.0000,
         -0.9999, -1.0000,  1.0000,  0.9998, -0.9740, -1.0000,  0.9999, -0.9981,
          0.3949,  0.9960,  0.9281, -0.9850,  0.8734, -1.0000, -0.1166,  0.9119,
          0.9983, -0.9729,  0.9999, -0.9999,  0.9485,  1.0000, -0.9886,  1.0000,
          0.3598, -1.0000,  1.0000, -1.0000, -0.9999, -0.1344,  1.0000,  0.9999,
          0.9721, -0.9948,  1.0000, -1.0000,  1.0000, -1.0000, -0.9784, -0.9999,
          1.0000, -0.9999, -0.9995, -0.9961,  0.9934, -0.6841, -0.9959,  1.0000,
          0.9830, -0.0785,  0.0016, -0.9982, -0.9997, -0.9993, -0.9789, -1.0000,
          0.9880,  0.6295, -0.9854, -0.9989, -1.0000,  1.0000,  0.9102, -0.9993,
          1.0000,  0.6416, -1.0000,  0.9998, -1.0000,  0.9974,  0.9995,  0.9282,
          0.9281, -1.0000,  0.9451,  1.0000, -0.9977, -0.1010, -0.8494, -0.9989,
          0.9680,  0.9994,  0.9994, -0.9982,  0.9991,  0.9782,  0.9994, -0.9993,
          0.5694, -1.0000, -0.9947, -0.9771, -0.5015, -1.0000, -1.0000,  1.0000,
          1.0000,  1.0000, -0.9997, -0.9630,  0.9945,  0.9997, -0.9997, -0.5486,
         -0.4082,  0.9995,  0.9815, -0.9995, -0.9222, -1.0000, -0.9999,  0.8829,
          0.9606,  0.9846,  1.0000,  1.0000, -0.9993, -0.9850, -1.0000, -1.0000,
          1.0000,  0.9998,  1.0000, -0.9935, -0.9935,  1.0000, -0.9598, -0.9793,
         -0.9996, -1.0000, -1.0000,  0.9801, -0.9996, -1.0000,  0.9997,  1.0000,
          0.2486, -1.0000, -0.9970,  1.0000,  1.0000,  1.0000,  0.9682,  1.0000,
         -0.9981,  0.9991, -0.9999,  1.0000, -1.0000,  1.0000,  1.0000,  0.9988,
          0.9996, -0.9998,  0.9529, -1.0000,  0.9445,  0.9934, -0.9129, -0.9996,
          0.7888, -0.7471, -0.9998,  1.0000,  0.9782, -0.9456,  0.9992,  0.8537,
          1.0000, -0.7332, -1.0000,  0.8291,  0.9964,  0.9996,  1.0000,  0.9961,
          1.0000, -0.9619, -0.9998,  0.9999, -0.9962, -0.9640, -1.0000,  1.0000,
          0.9997, -1.0000, -0.9990, -0.0881,  0.9771,  1.0000,  0.9998,  0.9856,
          0.9443,  0.7739,  0.9998, -1.0000,  0.9999, -0.9997, -0.9957,  1.0000,
         -0.9999,  0.9999, -0.9997,  0.9967, -1.0000,  0.3715,  0.9992,  0.9983,
         -0.9992,  1.0000,  0.9850, -0.9977, -0.9998, -0.9996, -0.9993,  0.9992]],
       grad_fn=<TanhBackward>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)
tensor([[ 4.2115e-01,  2.3145e-01, -2.0698e-01, -9.7974e-02,  4.9489e-01,
          2.7838e-01, -1.1931e-01,  1.2249e-01,  5.7828e-01, -1.4129e-02,
         -2.6775e-01,  9.0389e-02,  1.0490e-01,  2.3848e-01,  2.4841e-01,
         -5.6864e-01, -1.8439e-01, -1.8489e-01, -5.0514e-01,  1.3315e-01,
          4.9129e-01,  1.8245e-01, -2.6851e-01, -1.0386e+00, -7.9079e-01,
          1.4436e-01,  4.9574e-01, -9.7084e-02,  8.0020e-01, -4.6640e-01,
         -2.6626e-01, -2.8174e-01, -3.6215e-01,  5.1038e-01,  6.0100e-01,
         -4.5917e-01,  1.0059e-01, -3.4047e-01,  1.1231e-01,  1.1119e-01,
          2.3894e-01,  1.1199e-01,  1.9337e-01, -6.2538e-01, -7.4717e-02,
         -1.2579e-01,  6.5073e-01,  6.4615e-01, -5.8761e-02,  2.2246e-01,
         -2.5051e-01,  8.6206e+00,  1.4232e-02,  2.9667e-01, -1.3845e+00,
          1.1278e-01,  5.9608e-01,  5.9820e-02,  2.9962e-01, -2.5438e-01,
          2.8565e-02, -6.6922e-01,  3.6433e-01, -2.8556e-01, -7.1027e-01,
          1.4349e-01, -3.9762e-01,  2.2639e-01, -2.6528e-01, -9.4360e-02,
          4.1309e-01,  3.2021e-01, -9.2492e-04,  1.7673e-01, -3.3852e-01,
          8.5758e-01,  3.5135e-02, -5.2421e-01,  4.9388e-01, -1.5129e-01,
         -1.0801e+00, -1.6625e-01, -5.9788e-01, -8.1008e-02, -4.5741e-01,
         -8.6737e-02, -1.0806e+00, -2.2317e+00,  8.1670e-01,  2.4053e-01,
          3.0502e-03, -3.6126e-02, -2.3908e-01,  5.5755e-02,  5.4440e-01,
          3.0456e-01,  3.7178e-01,  2.6597e-02, -1.2701e-01,  7.3959e-01,
         -1.6025e-01,  1.2154e+00, -6.3310e-01,  1.2381e+00, -2.8937e-01,
          3.7735e-01,  4.1021e-01, -3.5903e-02, -8.0253e-01,  1.5286e-02,
         -7.4596e-01, -1.5062e-01, -2.4371e-01,  2.9766e-01, -2.5416e-02,
          2.2347e-01, -1.2874e-01,  1.6149e-01, -7.6720e-01, -2.2865e-01,
          8.2740e-02,  5.8240e-01,  5.9125e-01, -9.1472e-01,  1.8060e-02,
         -8.7869e-01, -2.0138e-01,  2.2509e-01,  2.0105e-01,  5.2098e-01,
          1.6654e-01, -7.5825e-03, -7.6159e-01,  3.6378e-01,  8.5763e-01,
         -4.7668e-01,  7.9492e-01, -1.3293e-01, -9.8156e-01, -5.2239e-02,
         -1.1053e-01, -2.3038e-01, -1.8885e-01,  5.8654e-02, -7.5245e-02,
          5.7493e-01, -5.8965e-01, -8.3425e-01, -9.1647e-02, -3.9919e-01,
         -5.9551e-01, -3.8222e-01, -3.9354e-01, -1.2872e+00,  1.7037e-01,
          6.2567e-02,  3.7266e-01, -4.3533e-01, -1.5423e-01,  8.6063e-02,
          1.4821e-02, -5.9938e-01,  1.2751e-01,  3.0867e-01,  2.7782e-01,
         -1.0011e+00,  8.3422e-01, -1.4378e-01, -4.3895e-01, -1.1647e-01,
          3.4525e-01, -3.5966e-01,  3.6655e-01, -5.9473e-01, -3.9539e-01,
          2.7654e-01,  5.3645e-02, -8.5819e-01, -2.1018e-01, -1.4444e+00,
          1.9370e-01, -7.4787e-01,  6.9408e-01, -2.7827e-01,  1.2964e-01,
         -8.9586e-02, -2.4729e-01,  3.4385e-01, -7.2836e-01,  1.0293e-01,
          6.1057e-01,  1.2272e-01, -1.1748e+00,  2.0129e-01,  5.6979e-01,
          2.6673e-02, -5.0832e-01,  3.2063e-01,  8.4628e-01,  8.5049e-01,
         -1.0433e-02, -1.0727e+00, -9.1567e-01,  2.1140e-01,  3.1516e-01,
          6.1549e-01,  4.2328e-01, -1.0585e-01, -4.5701e-01,  9.7370e-01,
          5.0359e-01, -4.9752e-02,  3.3154e-01,  1.8237e-01,  1.4244e-01,
          1.8660e-01,  2.9726e-01, -5.3041e-01,  1.0201e-02, -2.2264e-01,
          1.1000e-01,  1.0814e-01,  6.9984e-01,  2.5458e-01, -3.6922e-02,
          1.8352e-02,  6.2996e-01, -1.2014e-02, -5.9554e-01,  6.6884e-02,
         -4.9590e-01, -1.5742e-01, -6.2079e-02, -3.6228e-01, -4.4048e-01,
          4.2865e-01,  4.4279e-02, -6.4402e-01, -7.8738e-01, -4.3362e-02,
          2.1974e-01, -3.6023e-01, -1.1008e+00,  5.3952e-02, -3.6299e-01,
          2.6865e-01,  1.3385e-01,  1.9683e-01, -1.7107e-01,  2.3446e-01,
         -2.2911e-01, -5.3821e-01, -1.9868e-01, -1.3766e+00, -2.4666e-01,
          6.4290e-01, -6.9668e-01, -7.4800e-01, -5.5722e-01, -4.4907e-01,
         -1.3621e-01,  4.2140e-01,  9.5959e-01,  4.4360e-01,  3.7554e-01,
         -1.8354e-01,  9.7835e-02, -5.5753e-01, -3.2415e-01, -8.0041e-02,
          4.4006e-01,  2.6940e-01, -3.0654e-01, -2.3053e-01, -5.2161e-01,
         -6.4866e-02,  4.9422e-01,  7.5057e-01,  1.8852e-01, -3.3040e-01,
          1.2468e+00,  2.9247e-02, -6.6896e-01, -9.5221e-02, -5.9996e-01,
         -1.9553e-01,  3.8359e-02,  3.2242e-01, -1.8793e-01, -5.3541e-02,
         -3.1233e-01,  3.8802e-01,  1.3228e-01,  1.0668e+00,  1.8578e+00,
         -8.0071e-01,  1.2525e-01, -4.9549e-01,  2.0179e-01, -8.4097e-01,
          1.8610e-01, -4.7953e-02, -3.0395e-01,  8.5403e-01, -5.3822e-02,
         -5.0801e-01,  8.2333e-01,  5.0291e-01, -9.3993e-01, -4.6249e-01,
         -4.1557e-01, -2.2973e-01, -2.2915e-01, -6.1189e-01, -2.7792e-01,
          3.9598e-01, -3.4302e-01,  6.4712e-01,  2.2012e-01,  4.4055e-01,
         -4.4096e-01,  6.9279e-02, -2.8966e-01,  1.0388e+00, -1.1961e-01,
         -8.2876e-02, -1.4306e-01,  2.9727e-01, -1.2053e-01,  7.9857e-01,
          9.0269e-03, -1.5311e-01, -5.4600e-01,  8.1083e-01,  1.1571e+00,
         -3.9596e-01,  6.7261e-01,  4.7498e-01, -7.7686e-01, -2.9285e-01,
         -1.5991e-01,  1.4495e-01,  2.2916e-01,  5.6729e-02, -1.9356e-01,
          4.4161e-01,  4.0350e-01,  6.1182e-01,  4.2114e-01, -3.9481e-01,
         -6.3641e-01,  5.6452e-02, -8.1091e-01, -1.0413e-03, -4.5391e-01,
          1.7060e+00, -4.5662e-01,  5.8219e-02,  1.4496e+00, -1.3005e-01,
          1.2138e-01,  1.8672e-01, -4.0238e-01, -2.9981e-01,  2.5241e-01,
         -2.9240e-01, -4.5109e-02, -6.0901e-01,  7.5708e-01, -5.5898e-01,
          6.4741e-01,  5.7061e-01, -5.5011e-01, -7.6833e-01, -1.2637e-01,
         -4.7630e-01,  3.6280e-01, -1.6268e-01,  4.3015e-01,  3.1131e-02,
         -1.7380e-01, -2.1078e-01, -6.3284e-02, -7.5465e-01, -1.6997e-01,
          4.2480e-03,  4.7350e-01, -9.3531e-01, -1.1029e-01, -1.9922e-01,
          4.4874e-01, -1.3954e-01, -6.0780e-01, -9.7392e-02,  5.7029e-01,
          1.0055e+00,  2.3332e-01, -8.8975e-02, -6.3464e-02,  1.2084e+00,
          7.7065e-01, -1.1713e-01,  1.7554e-01, -7.2155e-01,  7.5223e-02,
          3.6250e-01,  4.5526e-01,  6.7044e-01,  2.8244e-01, -1.2149e+00,
         -1.1918e+00,  2.2679e-01,  9.8096e-01,  2.1820e-01,  1.4098e-01,
         -1.7236e-01,  2.2309e-01,  1.5257e-01,  6.6268e-01, -1.0398e-01,
         -3.0996e-01,  7.0454e-01, -2.6190e-01, -1.0244e-01,  6.4382e-01,
         -4.8694e-01, -5.3009e-01,  2.8935e-01, -5.8808e-01, -1.4827e-01,
         -8.6694e-01, -7.6133e-02,  2.6520e-01, -6.9277e-01,  7.4502e-01,
          8.5377e-01, -3.6214e-01,  4.8850e-01, -3.8880e-01,  5.1802e-01,
         -3.7141e-01, -7.7974e-01,  6.9056e-01, -4.9195e-01,  1.2370e-01,
          5.5867e-01,  6.1279e-01, -3.0551e-01,  6.3116e-01, -2.2843e-01,
          7.1648e-02, -4.6328e-01, -5.1349e-01, -3.3412e-01,  1.8214e-01,
         -6.5556e-01,  4.2179e-01,  6.4876e-01,  6.9453e-02,  1.9235e-01,
         -3.9381e-01,  1.8176e-01,  2.6222e-01,  4.9721e-01,  4.9734e-01,
          3.7234e-02, -4.3901e-01, -1.4314e-01,  7.3539e-01, -3.5594e-02,
          1.0551e-01,  8.7454e-02,  1.1364e-01,  6.1150e-01, -4.7138e-01,
          1.0484e-01, -1.5497e-01, -1.4640e-02, -1.2692e+00,  8.9680e-03,
          1.6948e-01, -8.3116e-01,  1.3112e+00, -8.9783e-01,  7.2376e-01,
          3.6621e-01,  1.6011e-01,  1.2368e-02, -6.5463e-01, -4.1487e-02,
         -9.6702e-01, -2.8059e-01,  2.6916e-01,  6.7283e-01, -5.2141e-01,
         -2.5983e-01,  1.8001e-01,  1.0120e-01, -2.0587e-01, -4.9314e-01,
         -6.8490e-01, -1.5255e+00,  1.8295e-01, -2.5930e-01, -1.2546e+00,
         -6.9589e-02,  4.3109e-01, -2.0965e-03, -3.3771e-01,  3.5798e-01,
          9.4842e-01, -4.0370e-01, -5.9109e-01, -3.0457e-02,  8.6064e-01,
          2.0668e-01, -3.9533e-02, -1.1155e+00,  1.6056e-01, -6.5845e-01,
          8.1298e-01, -6.7716e-01,  7.3191e-01, -6.5311e-03, -8.2889e-01,
          2.2935e-01, -1.0802e+00, -2.7448e-01,  1.6502e-01,  6.5893e-01,
         -2.7223e-01, -7.5991e-02,  2.9284e-01, -5.6872e-01,  6.5480e-01,
          7.2336e-02,  1.8865e-01, -1.0025e+00, -3.8879e-01, -6.7788e-01,
         -1.2397e-01, -3.5319e-01, -2.9062e-01, -2.1113e+00, -2.8191e-01,
         -7.8004e-01,  1.7187e-01,  7.2854e-01, -1.2840e-01,  1.7805e-01,
         -9.9736e-01,  6.6500e-01, -1.1308e+00,  8.9092e-02, -7.6484e-02,
         -3.3015e-01, -8.3021e-03,  2.0718e-01,  1.5108e-01, -4.2013e-01,
          2.6676e-01, -7.4054e-02,  3.5822e-01, -5.4754e-01, -6.3731e-01,
          1.6186e-01, -3.3157e-01,  4.8395e-01, -3.2022e-01, -7.7350e-01,
          7.0298e-01,  1.9313e-01, -1.4782e+00,  1.6191e-01,  7.8461e-01,
          7.4405e-01, -4.7152e-01,  9.6684e-02,  6.4955e-01, -3.2820e-01,
          7.1990e-01,  1.3264e+00, -4.5643e-01,  3.7712e-01,  3.0739e-01,
          4.9210e-01, -8.8734e-01,  7.0449e-01,  4.2371e-01, -6.9552e-01,
         -4.3504e-01,  7.1409e-02, -3.4701e-01,  1.1066e-01, -6.0244e-01,
          4.9573e-02, -1.2310e-01, -2.2023e-01, -8.0928e-01,  1.0534e-01,
          2.8099e-01, -4.2681e-01,  4.7401e-02,  1.5718e-01,  3.3319e-01,
         -6.7182e-01,  7.1283e-01,  1.1922e+00, -9.8484e-02,  1.4709e-02,
          7.8795e-01,  3.9003e-01,  7.0788e-02, -5.5820e-01, -1.6251e-01,
          4.7181e-01, -4.2967e-01,  2.4140e-01, -1.6781e-01, -7.4418e-01,
          4.0156e-01,  1.0813e+00,  5.5052e-01, -4.2373e-01, -6.1423e-01,
         -1.7941e-01,  2.7244e-02,  4.9143e-01,  2.2138e-01, -7.6815e-01,
         -8.6145e-01,  4.8813e-01, -8.6220e-01,  4.5893e-01, -1.9820e-02,
          1.0818e-01, -4.0966e-01,  7.4174e-01, -3.6610e-01,  2.1310e-02,
          8.0221e-01, -1.8562e-01,  5.8339e-02,  4.0976e-01,  7.0982e-01,
         -3.1056e-01,  1.5611e-01,  8.4351e-02,  1.8569e-01, -1.6442e-02,
          6.1764e-01, -5.8445e-02, -3.6498e-01,  4.0751e-01, -3.0788e-01,
         -9.5406e-02,  7.1717e-01,  2.4616e-01, -3.7505e-01, -2.1185e-01,
          3.5740e-01, -4.4704e-01, -1.9303e-01, -2.2978e-02, -1.7094e-01,
         -5.1454e-01, -7.2254e-01, -7.0447e-01,  2.7950e-01,  6.4647e-01,
          7.6261e-01, -3.2582e-01,  1.0665e+00, -8.2867e-01,  3.0364e-01,
         -3.7458e-02, -7.8399e-01, -2.0923e-01, -2.5211e-01, -2.0208e-01,
          2.0298e-01, -6.1091e-01, -1.8543e-01, -5.2403e-01,  1.6136e-05,
          5.4596e-01, -9.8686e-01, -3.3052e-01,  3.0897e-01,  4.7233e-01,
          3.8098e-02,  9.4033e-01,  3.2290e-01,  8.6700e-03, -2.4415e-01,
          3.9516e+00, -4.7525e-01,  7.1882e-01,  2.5092e-01, -1.0444e-01,
          1.9712e-01,  7.1822e-01, -1.3286e-01, -3.0513e-01, -1.0075e-01,
          6.1208e-02, -3.2429e-01, -4.3805e-01,  2.2502e-01,  3.1383e-03,
         -8.6180e-01, -5.3833e-01,  2.1707e-01, -8.0440e-02, -1.5623e-01,
          1.3789e+00,  2.2651e-01, -6.3904e-01, -5.1710e-01, -6.7277e-01,
         -6.8123e-01, -4.4332e-01,  1.1069e+00,  3.0457e-01, -4.2791e-01,
         -2.5976e-01, -4.9939e-01,  8.8211e-02,  5.7401e-01, -1.2304e-01,
         -5.9881e-01,  4.9584e-01,  1.1418e-01,  3.9800e-01,  3.0608e-02,
          8.9072e-01,  2.0687e-01,  2.7497e-01,  1.6121e-01,  1.1269e+00,
         -3.1077e-01, -1.0664e+00, -1.5314e-01, -4.9027e-01,  1.9091e-01,
          3.8845e-01,  1.7292e-01, -6.2660e-01, -1.0976e-01,  5.3393e-01,
         -6.6391e-01,  5.6167e-01,  8.1336e-01, -3.8923e-04,  1.5196e-01,
         -3.2085e-01,  2.7172e+00, -5.1197e-01,  9.2552e-01, -4.9456e-01,
         -3.3745e-01,  2.8565e-01,  2.3598e-01, -1.4363e-01,  1.5846e-01,
         -8.3367e-02, -4.9820e-01,  5.0628e-01]], grad_fn=<MeanBackward1>)
torch.Size([1, 768])

需要说明的是,这里的bert中文提取是直接将一行中文分割成单字再做one-hot编码,比如这里的input是 “春眠不觉晓” ,tokenizer这个函数则会将其转变成tensor([[ 101, 3217, 4697, 679, 6230, 3236, 102]]), 其中 101, 102分别对应的是"[CLS]" "[SEP]"这两个字符,中间的五个数字才是输入中文的one-hot编码。

整个句子经过bert特征提取之后,输出的是17768维的矩阵,其实就是每个字的特征为768,最后v = torch.mean(outputs[0], dim=1) # shape (1, 768)这行代码是将整个句子所有字的特征做一个加和求平均,得到的是整个句子的特征向量

参考

BERT 模型详解
BERT模型详解
transformers

你可能感兴趣的:(bert,人工智能,深度学习)