huggingface-BertModel/BertTokenizer

1. Importing the modules

from transformers import BertTokenizer, BertModel

2. Defining the model

model = BertModel.from_pretrained("bert-base-uncased",output_hidden_states=True)

With output_hidden_states=True, the model also returns the hidden states of the embedding layer and of every encoder layer, not just the last one.
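To see what the flag changes without downloading any weights, the sketch below builds a deliberately tiny, randomly initialized BertModel from a BertConfig. The configuration sizes, the name tiny_model, and the token ids are illustrative assumptions, not the settings of bert-base-uncased, and the flag is passed at call time rather than to from_pretrained.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration (random weights, no download); these sizes
# are illustrative and are not those of bert-base-uncased.
config = BertConfig(
    vocab_size=3000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
tiny_model = BertModel(config)
tiny_model.eval()  # disable dropout so the forward pass is deterministic

# Illustrative token ids; any ids below vocab_size would do.
input_ids = torch.tensor([[101, 2129, 2024, 2017, 1029, 102]])

with torch.no_grad():
    without_states = tiny_model(input_ids=input_ids)
    with_states = tiny_model(input_ids=input_ids, output_hidden_states=True)

# Without the flag the field is simply not populated.
print(without_states.hidden_states)   # None

# With the flag: one tensor for the embeddings plus one per encoder layer.
print(len(with_states.hidden_states))  # 3, i.e. num_hidden_layers + 1
```

With the 12-layer bert-base-uncased model the same rule gives 12 + 1 = 13 entries.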

3. Tokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokened = tokenizer("How are you?", return_tensors='pt')

With return_tensors='pt', the tokenizer returns PyTorch tensors.

Printing tokened gives:

{'input_ids': tensor([[ 101, 2129, 2024, 2017, 1029,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

It contains three entries: input_ids, token_type_ids, and attention_mask.
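In the single-sentence example above attention_mask is all ones and token_type_ids is all zeros, so their role is easy to miss. The following hand-written sketch (not the tokenizer's own implementation) shows how the three fields would look once two sequences of different length are padded into one batch. The helper name pad_batch and the second, shorter sequence are hypothetical; the pad id of 0 is an assumption that matches the [PAD] token of bert-base-uncased.

```python
def pad_batch(sequences, pad_id=0):
    """Pad token-id lists to equal length and build the matching masks."""
    max_len = max(len(seq) for seq in sequences)
    batch = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for seq in sequences:
        n_pad = max_len - len(seq)
        batch["input_ids"].append(seq + [pad_id] * n_pad)
        # Single-sentence inputs: every position belongs to segment 0.
        batch["token_type_ids"].append([0] * max_len)
        # 1 marks a real token, 0 marks padding the model should ignore.
        batch["attention_mask"].append([1] * len(seq) + [0] * n_pad)
    return batch


batch = pad_batch([
    [101, 2129, 2024, 2017, 1029, 102],  # "How are you?" from the output above
    [101, 2129, 102],                    # a shorter, illustrative sequence
])
for name, rows in batch.items():
    print(name, rows)
```

In practice the tokenizer builds these fields itself when called on a list of sentences with padding=True.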

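The input has four tokens (how, are, you, ?), but each output tensor has six entries along the sequence dimension (shape [1, 6], not six dimensions), because the tokenizer automatically adds [CLS] at the start and [SEP] at the end.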

Decoding input_ids back to text:

print(tokenizer.decode( [101, 2129, 2024, 2017, 1029, 102]))

[CLS] how are you? [SEP]
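The id-to-token correspondence in this example can be reproduced with a toy vocabulary. This is only a sketch: TOY_VOCAB, toy_encode, and toy_tokens are hypothetical names, the vocabulary holds just the six ids seen above (the real one has roughly 30,000 entries), and real WordPiece tokenization also handles lowercasing and subword splitting, which are skipped here.

```python
# Toy vocabulary restricted to the ids that appear in this example.
TOY_VOCAB = {
    "[CLS]": 101,
    "[SEP]": 102,
    "how": 2129,
    "are": 2024,
    "you": 2017,
    "?": 1029,
}
ID_TO_TOKEN = {idx: tok for tok, idx in TOY_VOCAB.items()}


def toy_encode(tokens):
    """Map tokens to ids and wrap them in [CLS] ... [SEP]."""
    ids = [TOY_VOCAB[tok] for tok in tokens]
    return [TOY_VOCAB["[CLS]"]] + ids + [TOY_VOCAB["[SEP]"]]


def toy_tokens(ids):
    """Map ids back to their token strings."""
    return [ID_TO_TOKEN[idx] for idx in ids]


ids = toy_encode(["how", "are", "you", "?"])
print(ids)              # [101, 2129, 2024, 2017, 1029, 102]
print(toy_tokens(ids))  # ['[CLS]', 'how', 'are', 'you', '?', '[SEP]']
```

With the real tokenizer, tokenizer.convert_ids_to_tokens gives the same per-token view, whereas decode joins the tokens back into a single string.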

4. BERT output

output = model(input_ids=tokened['input_ids'])

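The output contains three parts:

  • last_hidden_state: the hidden states of every token in the sentence at the last encoder layer (this is what you use when BERT serves as an embedding layer).
  • pooler_output: the last-layer hidden state of the [CLS] token, passed through an additional linear layer and a tanh activation. Because [CLS] is not tied to any particular word, it is considered to blend the semantics of all tokens more evenly, so it is commonly used as a whole-sentence representation for downstream classification tasks.
  • hidden_states: the hidden states output by each layer of BERT (13 in total for bert-base-uncased; the first is the embedding output, followed by the 12 encoder layers).

# Shape of last_hidden_state
print(output[0].shape)

# Shape of pooler_output
print(output[1].shape)

# Type and length of hidden_states
print(type(output[2]), len(output[2]))

# Shape of the first entry in hidden_states (the embedding output)
print(output[2][0].shape)

Output:

torch.Size([1, 6, 768]) # (batch_size, sequence_length, hidden_size)
torch.Size([1, 768]) # (batch_size, hidden_size)
<class 'tuple'> 13
torch.Size([1, 6, 768]) # (batch_size, sequence_length, hidden_size)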
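How the three parts relate to each other can be checked directly. The sketch below again uses a tiny, randomly initialized model so that it runs without downloading weights; the configuration sizes and the name tiny_model are illustrative assumptions, and the numeric values it produces are meaningless. Only the structure carries over to bert-base-uncased.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration with random weights (no download needed).
config = BertConfig(
    vocab_size=3000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
tiny_model = BertModel(config)
tiny_model.eval()  # disable dropout

input_ids = torch.tensor([[101, 2129, 2024, 2017, 1029, 102]])

with torch.no_grad():
    out = tiny_model(input_ids=input_ids, output_hidden_states=True)

    # Named attributes and integer indices refer to the same tensors.
    same_last = torch.equal(out[0], out.last_hidden_state)
    same_pool = torch.equal(out[1], out.pooler_output)

    # The last entry of hidden_states is the last encoder layer's output.
    last_matches = torch.equal(out.hidden_states[-1], out.last_hidden_state)

    # Recompute the pooled vector from the [CLS] position (index 0):
    # a linear layer followed by tanh.
    cls_state = out.last_hidden_state[:, 0]
    manual_pool = torch.tanh(tiny_model.pooler.dense(cls_state))
    pool_matches = torch.allclose(manual_pool, out.pooler_output, atol=1e-6)

print(same_last, same_pool, last_matches, pool_matches)
```

Integer indexing skips fields that are None, so which index hidden_states lands on depends on the outputs that are enabled; accessing out.hidden_states by name tends to be the more robust choice.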

 

Appendix: full BERT output

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0573,  0.1625, -0.4028,  ..., -0.4454,  0.2880,  0.3058],
         [-0.1744, -0.4683, -0.2665,  ..., -0.2229,  0.7547, -0.4456],
         [ 0.3816, -0.6513,  0.3829,  ..., -0.7145,  0.1510, -0.7006],
         [-0.1328, -0.6548,  0.9204,  ..., -0.4545,  0.3768, -1.0013],
         [-0.0467, -0.7707, -0.9214,  ...,  0.1693,  0.3671, -0.2546],
         [ 0.6713, -0.0606, -0.2756,  ...,  0.1626, -0.3610, -0.2107]]],
       grad_fn=), pooler_output=tensor([[-0.9639, -0.4993, -0.9370,  0.9181,  0.6590, -0.2347,  0.9709,  0.3910,
         -0.8730, -1.0000, -0.6252,  0.9708,  0.9868,  0.7147,  0.9811, -0.9054,
         -0.6828, -0.7329,  0.3840, -0.8714,  0.8462,  1.0000, -0.1027,  0.4062,
          0.6092,  0.9969, -0.9039,  0.9731,  0.9827,  0.7820, -0.8712,  0.1885,
         -0.9914, -0.2279, -0.9189, -0.9961,  0.4929, -0.8633, -0.1680, -0.0422,
         -0.9448,  0.3537,  1.0000,  0.2448,  0.4656, -0.4160, -1.0000,  0.3710,
         -0.9528,  0.9701,  0.9430,  0.8852,  0.3447,  0.5873,  0.5981, -0.1670,
          0.0284,  0.1660, -0.3779, -0.6841, -0.6573,  0.6169, -0.9182, -0.9448,
          0.9527,  0.8685, -0.2310, -0.3400, -0.2363, -0.0459,  0.9648,  0.3784,
          0.0015, -0.9229,  0.7714,  0.3504, -0.8224,  1.0000, -0.8578, -0.9867,
          0.8529,  0.8783,  0.7610, -0.4921,  0.5866, -1.0000,  0.7434, -0.1500,
         -0.9944,  0.3092,  0.6802, -0.3913,  0.4249,  0.7614, -0.7716, -0.5902,
         -0.4421, -0.8968, -0.4131, -0.4213,  0.1690, -0.3365, -0.6165, -0.5304,
          0.4636, -0.5859, -0.7818,  0.4761,  0.0640,  0.7744,  0.5892, -0.4240,
          0.6625, -0.9849,  0.7946, -0.4982, -0.9932, -0.7690, -0.9928,  0.7947,
         -0.4692, -0.3429,  0.9878, -0.2252,  0.6162, -0.1899, -0.9633, -1.0000,
         -0.8550, -0.6835, -0.3601, -0.4255, -0.9867, -0.9773,  0.7461,  0.9780,
          0.3300,  1.0000, -0.4687,  0.9732, -0.4603, -0.7008,  0.7487, -0.6143,
          0.8269,  0.7566, -0.8236,  0.2973, -0.4219,  0.6640, -0.8196, -0.3078,
         -0.8190, -0.9741, -0.4802,  0.9728, -0.5953, -0.9732, -0.0724, -0.4220,
         -0.6787,  0.9119,  0.8096,  0.5509, -0.4881,  0.5235,  0.5619,  0.7361,
         -0.9376, -0.2613,  0.6226, -0.4266, -0.8814, -0.9843, -0.5397,  0.6765,
          0.9952,  0.8805,  0.4681,  0.8854, -0.3759,  0.8607, -0.9751,  0.9881,
         -0.3364,  0.3848,  0.1099,  0.4441, -0.9595,  0.1625,  0.9457, -0.7576,
         -0.9179, -0.2893, -0.5712, -0.5417, -0.8286,  0.7414, -0.4026, -0.3658,
         -0.1692,  0.9522,  0.9966,  0.9020,  0.2979,  0.8393, -0.9540, -0.6894,
          0.1825,  0.4281,  0.1870,  0.9973, -0.7403, -0.2360, -0.9683, -0.9922,
          0.0485, -0.9644, -0.2205, -0.7313,  0.7964,  0.1598,  0.7443,  0.6463,
         -0.9982, -0.8792,  0.5282, -0.6203,  0.5653, -0.3124,  0.3869,  0.9590,
         -0.7478,  0.9321,  0.9450, -0.8777, -0.8455,  0.9289, -0.4548,  0.9536,
         -0.7856,  0.9961,  0.9635,  0.9318, -0.9671, -0.7757, -0.9608, -0.8807,
         -0.1498,  0.3246,  0.9498,  0.7949,  0.5329,  0.2485, -0.8161,  0.9999,
         -0.3527, -0.9770,  0.0037, -0.3620, -0.9924,  0.9177,  0.4558,  0.2058,
         -0.5407, -0.8518, -0.9781,  0.9580,  0.2718,  0.9968, -0.2733, -0.9850,
         -0.7317, -0.9557, -0.0782, -0.4049, -0.4331, -0.1381, -0.9814,  0.6231,
          0.6514,  0.6933, -0.8871,  0.9999,  1.0000,  0.9832,  0.9480,  0.9792,
         -1.0000, -0.4907,  1.0000, -0.9972, -1.0000, -0.9637, -0.7868,  0.5556,
         -1.0000, -0.2827, -0.1662, -0.9524,  0.8380,  0.9865,  0.9993, -1.0000,
          0.8841,  0.9749, -0.8301,  0.9880, -0.5627,  0.9809,  0.6562,  0.3913,
         -0.4112,  0.4981, -0.9584, -0.9537, -0.6507, -0.7746,  0.9988,  0.3178,
         -0.8923, -0.9569,  0.4846, -0.2291, -0.2566, -0.9813, -0.2779,  0.7789,
          0.9105,  0.2398,  0.4435, -0.8275,  0.4358, -0.0067,  0.5701,  0.8237,
         -0.9564, -0.7212, -0.4226, -0.1180, -0.7685, -0.9754,  0.9858, -0.5656,
          0.9535,  1.0000,  0.1815, -0.9606,  0.7683,  0.3945, -0.2715,  1.0000,
          0.8628, -0.9854, -0.7557,  0.6681, -0.6831, -0.7360,  0.9999, -0.2509,
         -0.7722, -0.5705,  0.9849, -0.9924,  0.9971, -0.9724, -0.9825,  0.9837,
          0.9621, -0.8634, -0.7938,  0.3014, -0.8191,  0.3442, -0.9889,  0.8583,
          0.7474, -0.2749,  0.9386, -0.9649, -0.7691,  0.4106, -0.7648, -0.2074,
          0.9649,  0.6622, -0.4458,  0.1391, -0.4272, -0.2285, -0.9903,  0.4343,
          1.0000, -0.3728,  0.7500, -0.6096, -0.0782, -0.0467,  0.5787,  0.7230,
         -0.4480, -0.9231,  0.8529, -0.9942, -0.9883,  0.9108,  0.2968, -0.4507,
          1.0000,  0.6721,  0.2267,  0.4615,  0.9947,  0.1456,  0.8221,  0.9437,
          0.9908, -0.3572,  0.7612,  0.9592, -0.9562, -0.4870, -0.7562,  0.0359,
         -0.9192, -0.1009, -0.9800,  0.9813,  0.9810,  0.5379,  0.3425,  0.6109,
          1.0000, -0.3027,  0.7574, -0.7509,  0.9638, -0.9998, -0.9582, -0.5514,
         -0.2216, -0.8854, -0.4296,  0.4852, -0.9855,  0.9070,  0.7439, -0.9990,
         -0.9959, -0.2874,  0.9608,  0.2318, -0.9910, -0.8865, -0.6975,  0.7487,
         -0.4122, -0.9673, -0.3669, -0.4526,  0.6604, -0.3187,  0.7298,  0.9125,
          0.4807, -0.6056, -0.4237, -0.1390, -0.9107,  0.9518, -0.9408, -0.9633,
         -0.3652,  1.0000, -0.6232,  0.9462,  0.8387,  0.8843, -0.2297,  0.3219,
          0.9521,  0.3487, -0.8649, -0.9282, -0.9284, -0.5327,  0.8448,  0.3688,
          0.7981,  0.8851,  0.7787,  0.2426, -0.1275,  0.0627,  1.0000, -0.3206,
         -0.2520, -0.6673, -0.1260, -0.5247, -0.7049,  1.0000,  0.4109,  0.5311,
         -0.9940, -0.9166, -0.9730,  1.0000,  0.8794, -0.9183,  0.8293,  0.7381,
         -0.2460,  0.9386, -0.2837, -0.4118,  0.3727,  0.2090,  0.9740, -0.6829,
         -0.9815, -0.7682,  0.5797, -0.9837,  1.0000, -0.7271, -0.3817, -0.5546,
         -0.3283,  0.7684, -0.0143, -0.9911, -0.3802,  0.3628,  0.9782,  0.4206,
         -0.8047, -0.9602,  0.8745,  0.8764, -0.9573, -0.9632,  0.9773, -0.9949,
          0.7125,  1.0000,  0.4571,  0.2506,  0.3976, -0.6670,  0.5187, -0.3151,
          0.8834, -0.9821, -0.3898, -0.2883,  0.4084, -0.2781, -0.2061,  0.8508,
          0.2817, -0.7630, -0.7728, -0.2177,  0.5364,  0.9567, -0.3668, -0.3175,
          0.2292, -0.2774, -0.9759, -0.3111, -0.6086, -1.0000,  0.8677, -1.0000,
          0.6034,  0.4101, -0.3864,  0.9284,  0.1872,  0.6729, -0.8836, -0.9023,
          0.3494,  0.8616, -0.4054, -0.7423, -0.8551,  0.4803, -0.1593,  0.2331,
         -0.6162,  0.8185, -0.3419,  1.0000,  0.2587, -0.8902, -0.9961,  0.2707,
         -0.3756,  1.0000, -0.9811, -0.9696,  0.5954, -0.8535, -0.9130,  0.4765,
          0.1160, -0.8599, -0.9679,  0.9881,  0.9737, -0.7367,  0.5846, -0.4798,
         -0.7181,  0.1214,  0.8910,  0.9905,  0.5837,  0.9690,  0.6102, -0.1811,
          0.9854,  0.3214,  0.7951,  0.2239,  1.0000,  0.4570, -0.9499,  0.0648,
         -0.9921, -0.3552, -0.9790,  0.4716,  0.3796,  0.9339, -0.2584,  0.9850,
         -0.8884,  0.1191, -0.8238, -0.5894,  0.5726, -0.9569, -0.9890, -0.9905,
          0.6442, -0.5990, -0.1902,  0.2571,  0.2937,  0.5793,  0.5845, -1.0000,
          0.9707,  0.6047,  0.9628,  0.9769,  0.7969,  0.5512,  0.3688, -0.9935,
         -0.9973, -0.4643, -0.3199,  0.8706,  0.7476,  0.9571,  0.5946, -0.5784,
         -0.1496, -0.5475, -0.3116, -0.9961,  0.6147, -0.7207, -0.9945,  0.9781,
         -0.1213, -0.2692, -0.1053, -0.8538,  0.9879,  0.9060,  0.6091,  0.0540,
          0.5773,  0.9452,  0.9863,  0.9897, -0.8878,  0.9530, -0.7677,  0.6045,
          0.6884, -0.9590,  0.2035,  0.6452, -0.6606,  0.4414, -0.3820, -0.9940,
          0.7435, -0.4090,  0.7173, -0.5816, -0.0113, -0.5246, -0.1963, -0.8200,
         -0.8477,  0.7834,  0.7695,  0.9511,  0.7920, -0.0710, -0.8780, -0.3213,
         -0.8836, -0.9467,  0.9839, -0.1764, -0.5626,  0.6764,  0.0374,  0.7217,
          0.2637, -0.4755, -0.4506, -0.8280,  0.9425, -0.5586, -0.6535, -0.7009,
          0.8099,  0.4119,  1.0000, -0.8556, -0.9608, -0.5079, -0.4905,  0.5642,
         -0.7158, -1.0000,  0.5819, -0.5215,  0.7922, -0.8036,  0.8845, -0.8239,
         -0.9957, -0.3756,  0.5139,  0.7862, -0.5883, -0.8074,  0.7892, -0.4830,
          0.9881,  0.9518, -0.7551,  0.0959,  0.8097, -0.7260, -0.8043,  0.9666]],
       grad_fn=), hidden_states=(tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.0720,  0.9108, -1.1910,  ...,  0.7400,  0.8613,  0.2174],
         [-0.2330,  0.1164,  0.5087,  ...,  0.3019,  0.1804,  0.3744],
         [-0.4960,  0.2760,  0.4670,  ...,  0.7051,  0.3014,  0.3416],
         [ 0.3939, -0.2877, -0.3005,  ...,  0.4445,  0.5525,  0.6004],
         [-0.3251, -0.3188, -0.1163,  ..., -0.3960,  0.4112, -0.0776]]],
       grad_fn=), tensor([[[ 0.0388,  0.1101, -0.1576,  ...,  0.1122,  0.0586, -0.0161],
         [ 0.2020,  1.4445, -1.2794,  ...,  0.1721,  0.5129,  0.4866],
         [ 0.4994,  0.3372,  0.6230,  ..., -0.4619,  0.3822, -0.0990],
         [-0.3165,  0.2699,  0.5134,  ...,  0.4806,  0.0789,  0.2612],
         [ 0.7581, -0.0802, -0.4836,  ...,  0.1307,  0.6724,  0.8874],
         [-0.0213, -0.2086, -0.1192,  ..., -0.5538,  0.3962,  0.0161]]],
       grad_fn=), tensor([[[-0.0946, -0.2108, -0.3258,  ...,  0.2693,  0.1359,  0.0787],
         [ 0.2205,  1.9918, -0.5364,  ...,  0.0188,  0.4090,  0.2235],
         [ 0.0895,  0.3188,  1.3691,  ..., -0.5914,  0.3385, -0.6825],
         [-0.4056, -0.5321,  0.1526,  ...,  0.2154,  0.0022,  0.1891],
         [ 0.4781, -0.3290, -0.5145,  ..., -0.1450,  0.4181,  0.9846],
         [-0.1271, -0.2587,  0.1175,  ..., -0.1840,  0.3792, -0.1247]]],
       grad_fn=), tensor([[[-0.0772, -0.3340, -0.0871,  ...,  0.3118,  0.1349,  0.2388],
         [-0.1580,  1.6831,  0.1133,  ...,  0.3931,  0.2327,  0.0816],
         [-0.1316, -0.1502,  1.4529,  ..., -0.6485,  0.1748, -0.7754],
         [-0.2039, -0.8913,  0.3630,  ...,  0.2297, -0.0181,  0.0804],
         [ 0.4439, -0.6840, -0.1635,  ...,  0.2165,  0.0798,  0.3781],
         [-0.0752, -0.1274,  0.1253,  ...,  0.0260,  0.0710, -0.0180]]],
       grad_fn=), tensor([[[ 0.0306, -0.5660, -0.7076,  ...,  0.5205,  0.0904,  0.5262],
         [-0.0427,  1.2388,  0.6547,  ...,  0.4277, -0.0733, -0.5716],
         [-0.3474, -0.7610,  1.7970,  ..., -0.6507,  0.1996, -0.3475],
         [-0.1496, -1.8310,  0.5126,  ...,  0.0577,  0.3315, -0.2452],
         [ 0.9380, -1.1031, -0.5902,  ...,  0.3442,  0.1777,  0.4778],
         [-0.0280, -0.0590,  0.0084,  ...,  0.0154,  0.0535, -0.0352]]],
       grad_fn=), tensor([[[ 0.0812, -0.2125, -0.6219,  ..., -0.0835,  0.3285,  0.4436],
         [ 0.2676,  0.9164, -0.2069,  ...,  0.1127, -0.2140, -0.3537],
         [-0.5943, -1.1156,  0.8255,  ..., -0.8022,  0.1315,  0.1764],
         [ 0.1365, -1.9111,  0.4468,  ...,  0.0412, -0.1455, -0.0240],
         [ 0.5411, -0.9937, -0.8214,  ...,  0.0132,  0.0558,  0.8175],
         [-0.0152, -0.0440,  0.0131,  ...,  0.0220,  0.0078, -0.0387]]],
       grad_fn=), tensor([[[ 1.6313e-01, -7.7907e-01, -7.2191e-01,  ..., -1.3117e-01,
           3.6851e-01,  3.8299e-01],
         [ 3.6151e-01,  1.0496e+00,  1.2720e-03,  ..., -4.6335e-01,
          -6.0164e-01, -4.5743e-01],
         [-2.2870e-01, -1.1642e+00,  4.7943e-01,  ..., -1.5846e+00,
           7.5402e-01,  1.0892e-01],
         [-2.2471e-01, -1.9071e+00,  4.3491e-01,  ..., -5.4294e-01,
           3.2445e-01, -4.0690e-01],
         [-1.4225e-02, -1.0958e+00, -1.2180e+00,  ..., -3.4226e-01,
          -2.5873e-01,  5.2984e-01],
         [ 1.1444e-02, -4.0294e-02, -1.7600e-02,  ...,  6.7591e-03,
          -2.1420e-02, -4.3917e-02]]], grad_fn=), tensor([[[ 0.3211, -0.2976, -0.7657,  ..., -0.5251,  0.4389,  0.4900],
         [ 0.1868,  1.0571,  0.6403,  ..., -0.6938, -0.5096, -0.2832],
         [-0.7042, -1.1779,  0.5186,  ..., -1.3521,  0.4648, -0.3709],
         [-0.8978, -1.9789,  0.5508,  ..., -0.6249,  0.6898, -0.6074],
         [-0.2716, -1.5542, -2.0756,  ..., -0.7194, -0.1692,  0.6462],
         [-0.0143, -0.0433, -0.0149,  ..., -0.0209,  0.0118, -0.0601]]],
       grad_fn=), tensor([[[ 0.1281, -0.2184, -1.3155,  ..., -0.9203,  0.2414,  0.8508],
         [ 0.2685,  0.4445,  0.6281,  ..., -0.5510, -0.7145, -0.2859],
         [-0.3381, -1.4646,  0.3618,  ..., -1.2901,  0.0663, -0.4174],
         [-0.4894, -1.8939,  0.5440,  ..., -0.4871,  0.8371, -0.3794],
         [-0.6856, -1.2587, -1.8686,  ..., -0.8897, -0.5912,  0.4858],
         [ 0.0081, -0.0234,  0.0395,  ..., -0.0532, -0.0211, -0.0851]]],
       grad_fn=), tensor([[[-0.0527, -0.0867, -1.7301,  ..., -0.4690,  0.0828,  0.3822],
         [-0.0247,  0.3531,  0.7252,  ..., -0.4155, -0.9026, -0.6999],
         [-0.0070, -1.0945,  0.3933,  ..., -0.7901, -0.3328, -0.5838],
         [-0.2029, -1.6412,  0.6373,  ..., -0.2943,  0.3732, -0.5496],
         [-1.1925, -1.0985, -1.7700,  ..., -0.8319, -0.1459, -0.6030],
         [ 0.0325, -0.0043,  0.0355,  ..., -0.0543, -0.0607, -0.0377]]],
       grad_fn=), tensor([[[ 0.0916, -0.1851, -0.9353,  ..., -0.8346, -0.5341,  0.3569],
         [ 0.1064, -0.8527,  0.4448,  ..., -0.3772, -1.0587, -1.1242],
         [ 0.2348, -1.5743,  0.1723,  ..., -0.5256, -0.7394, -0.7555],
         [-0.0975, -1.7546,  0.6622,  ..., -0.0796,  0.1267, -0.7318],
         [-0.6475, -1.0668, -1.2231,  ..., -0.5111,  0.0027, -0.2800],
         [ 0.0115, -0.0389,  0.0561,  ...,  0.0909, -0.0160, -0.0099]]],
       grad_fn=), tensor([[[ 1.9646e-01, -2.4143e-01, -4.1810e-01,  ..., -7.7052e-01,
           1.2252e-01, -4.2691e-03],
         [-3.0136e-01, -1.5718e-01,  1.8951e-01,  ..., -9.8252e-01,
          -1.9064e-01, -1.2972e+00],
         [-1.4842e-01, -6.8893e-01,  3.6466e-01,  ..., -1.1063e+00,
          -2.7100e-02, -9.8203e-01],
         [-3.7955e-01, -1.2046e+00,  7.2221e-01,  ..., -7.6463e-01,
           6.0633e-01, -7.2070e-01],
         [-3.6614e-01, -1.2337e+00, -9.5419e-01,  ..., -4.5989e-01,
           3.6219e-01, -6.3111e-02],
         [ 4.0417e-02,  3.7746e-04, -3.1245e-02,  ...,  1.5826e-02,
          -2.6091e-02,  2.6609e-02]]], grad_fn=), tensor([[[-0.0573,  0.1625, -0.4028,  ..., -0.4454,  0.2880,  0.3058],
         [-0.1744, -0.4683, -0.2665,  ..., -0.2229,  0.7547, -0.4456],
         [ 0.3816, -0.6513,  0.3829,  ..., -0.7145,  0.1510, -0.7006],
         [-0.1328, -0.6548,  0.9204,  ..., -0.4545,  0.3768, -1.0013],
         [-0.0467, -0.7707, -0.9214,  ...,  0.1693,  0.3671, -0.2546],
         [ 0.6713, -0.0606, -0.2756,  ...,  0.1626, -0.3610, -0.2107]]],
       grad_fn=)), past_key_values=None, attentions=None, cross_attentions=None)
