from transformers import BertTokenizer, BertModel
# Load the pretrained BERT encoder. output_hidden_states=True makes the
# forward pass additionally return the hidden states of the embedding layer
# and of every encoder layer, not just the final layer's output.
model = BertModel.from_pretrained("bert-base-uncased",output_hidden_states=True)
参数output_hidden_states=True,模型将额外输出embedding层以及每层encoder的隐层状态(因此下文hidden_states长度为13:1个embedding输出 + 12层encoder输出)。
# Tokenizer matching the same checkpoint as the model above.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# return_tensors='pt' makes the tokenizer return PyTorch tensors instead of
# plain Python lists.
tokened = tokenizer("How are you?", return_tensors='pt')
参数return_tensors='pt',返回的是pytorch tensor类型。
输出tokened如下:
{'input_ids': tensor([[ 101, 2129, 2024, 2017, 1029, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
包含三项:input_ids, token_type_ids, attention_mask。
输入了四个token(how, are, you, ?),但输出的tensor长度为6,因为tokenizer自动在句首和句末分别加了[CLS]和[SEP]两个特殊token。
将input_ids解码输出:
# Decode the token ids back to text; the [CLS]/[SEP] special tokens become
# visible in the decoded string.
input_ids = [101, 2129, 2024, 2017, 1029, 102]
print(tokenizer.decode(input_ids))
[CLS] how are you? [SEP]
# Forward pass with only input_ids; attention_mask and token_type_ids fall
# back to their defaults (fine here since there is no padding and a single
# segment).
output = model(input_ids=tokened['input_ids'])
output包含三个部分:last_hidden_state、pooler_output、hidden_states。
# Inspect the model output. Use named attributes rather than positional
# indexing (output[0], output[1], ...): converting a ModelOutput to a tuple
# skips fields that are None (past_key_values, attentions, ...), so the
# positional indices silently shift if the model config changes. The named
# form is robust and prints identical results here.
# Final encoder layer output: (batch_size, sequence_length, hidden_size).
print(output.last_hidden_state.shape)
# Pooled [CLS] representation: (batch_size, hidden_size).
print(output.pooler_output.shape)
# hidden_states is a tuple of 13 tensors: the embedding-layer output plus
# one tensor per encoder layer.
print(type(output.hidden_states),len(output.hidden_states))
# First element of hidden_states (the embedding-layer output):
# (batch_size, sequence_length, hidden_size).
print(output.hidden_states[0].shape)
torch.Size([1, 6, 768]) #(batchsize,max_sequence_length,hidden_size)
torch.Size([1, 768]) #(batchsize,hidden_size)
13
torch.Size([1, 6, 768]) #(batchsize, max_sequence_length, hidden_size)
附:BERT完整输出
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0573, 0.1625, -0.4028, ..., -0.4454, 0.2880, 0.3058],
[-0.1744, -0.4683, -0.2665, ..., -0.2229, 0.7547, -0.4456],
[ 0.3816, -0.6513, 0.3829, ..., -0.7145, 0.1510, -0.7006],
[-0.1328, -0.6548, 0.9204, ..., -0.4545, 0.3768, -1.0013],
[-0.0467, -0.7707, -0.9214, ..., 0.1693, 0.3671, -0.2546],
[ 0.6713, -0.0606, -0.2756, ..., 0.1626, -0.3610, -0.2107]]],
grad_fn=), pooler_output=tensor([[-0.9639, -0.4993, -0.9370, 0.9181, 0.6590, -0.2347, 0.9709, 0.3910,
-0.8730, -1.0000, -0.6252, 0.9708, 0.9868, 0.7147, 0.9811, -0.9054,
-0.6828, -0.7329, 0.3840, -0.8714, 0.8462, 1.0000, -0.1027, 0.4062,
0.6092, 0.9969, -0.9039, 0.9731, 0.9827, 0.7820, -0.8712, 0.1885,
-0.9914, -0.2279, -0.9189, -0.9961, 0.4929, -0.8633, -0.1680, -0.0422,
-0.9448, 0.3537, 1.0000, 0.2448, 0.4656, -0.4160, -1.0000, 0.3710,
-0.9528, 0.9701, 0.9430, 0.8852, 0.3447, 0.5873, 0.5981, -0.1670,
0.0284, 0.1660, -0.3779, -0.6841, -0.6573, 0.6169, -0.9182, -0.9448,
0.9527, 0.8685, -0.2310, -0.3400, -0.2363, -0.0459, 0.9648, 0.3784,
0.0015, -0.9229, 0.7714, 0.3504, -0.8224, 1.0000, -0.8578, -0.9867,
0.8529, 0.8783, 0.7610, -0.4921, 0.5866, -1.0000, 0.7434, -0.1500,
-0.9944, 0.3092, 0.6802, -0.3913, 0.4249, 0.7614, -0.7716, -0.5902,
-0.4421, -0.8968, -0.4131, -0.4213, 0.1690, -0.3365, -0.6165, -0.5304,
0.4636, -0.5859, -0.7818, 0.4761, 0.0640, 0.7744, 0.5892, -0.4240,
0.6625, -0.9849, 0.7946, -0.4982, -0.9932, -0.7690, -0.9928, 0.7947,
-0.4692, -0.3429, 0.9878, -0.2252, 0.6162, -0.1899, -0.9633, -1.0000,
-0.8550, -0.6835, -0.3601, -0.4255, -0.9867, -0.9773, 0.7461, 0.9780,
0.3300, 1.0000, -0.4687, 0.9732, -0.4603, -0.7008, 0.7487, -0.6143,
0.8269, 0.7566, -0.8236, 0.2973, -0.4219, 0.6640, -0.8196, -0.3078,
-0.8190, -0.9741, -0.4802, 0.9728, -0.5953, -0.9732, -0.0724, -0.4220,
-0.6787, 0.9119, 0.8096, 0.5509, -0.4881, 0.5235, 0.5619, 0.7361,
-0.9376, -0.2613, 0.6226, -0.4266, -0.8814, -0.9843, -0.5397, 0.6765,
0.9952, 0.8805, 0.4681, 0.8854, -0.3759, 0.8607, -0.9751, 0.9881,
-0.3364, 0.3848, 0.1099, 0.4441, -0.9595, 0.1625, 0.9457, -0.7576,
-0.9179, -0.2893, -0.5712, -0.5417, -0.8286, 0.7414, -0.4026, -0.3658,
-0.1692, 0.9522, 0.9966, 0.9020, 0.2979, 0.8393, -0.9540, -0.6894,
0.1825, 0.4281, 0.1870, 0.9973, -0.7403, -0.2360, -0.9683, -0.9922,
0.0485, -0.9644, -0.2205, -0.7313, 0.7964, 0.1598, 0.7443, 0.6463,
-0.9982, -0.8792, 0.5282, -0.6203, 0.5653, -0.3124, 0.3869, 0.9590,
-0.7478, 0.9321, 0.9450, -0.8777, -0.8455, 0.9289, -0.4548, 0.9536,
-0.7856, 0.9961, 0.9635, 0.9318, -0.9671, -0.7757, -0.9608, -0.8807,
-0.1498, 0.3246, 0.9498, 0.7949, 0.5329, 0.2485, -0.8161, 0.9999,
-0.3527, -0.9770, 0.0037, -0.3620, -0.9924, 0.9177, 0.4558, 0.2058,
-0.5407, -0.8518, -0.9781, 0.9580, 0.2718, 0.9968, -0.2733, -0.9850,
-0.7317, -0.9557, -0.0782, -0.4049, -0.4331, -0.1381, -0.9814, 0.6231,
0.6514, 0.6933, -0.8871, 0.9999, 1.0000, 0.9832, 0.9480, 0.9792,
-1.0000, -0.4907, 1.0000, -0.9972, -1.0000, -0.9637, -0.7868, 0.5556,
-1.0000, -0.2827, -0.1662, -0.9524, 0.8380, 0.9865, 0.9993, -1.0000,
0.8841, 0.9749, -0.8301, 0.9880, -0.5627, 0.9809, 0.6562, 0.3913,
-0.4112, 0.4981, -0.9584, -0.9537, -0.6507, -0.7746, 0.9988, 0.3178,
-0.8923, -0.9569, 0.4846, -0.2291, -0.2566, -0.9813, -0.2779, 0.7789,
0.9105, 0.2398, 0.4435, -0.8275, 0.4358, -0.0067, 0.5701, 0.8237,
-0.9564, -0.7212, -0.4226, -0.1180, -0.7685, -0.9754, 0.9858, -0.5656,
0.9535, 1.0000, 0.1815, -0.9606, 0.7683, 0.3945, -0.2715, 1.0000,
0.8628, -0.9854, -0.7557, 0.6681, -0.6831, -0.7360, 0.9999, -0.2509,
-0.7722, -0.5705, 0.9849, -0.9924, 0.9971, -0.9724, -0.9825, 0.9837,
0.9621, -0.8634, -0.7938, 0.3014, -0.8191, 0.3442, -0.9889, 0.8583,
0.7474, -0.2749, 0.9386, -0.9649, -0.7691, 0.4106, -0.7648, -0.2074,
0.9649, 0.6622, -0.4458, 0.1391, -0.4272, -0.2285, -0.9903, 0.4343,
1.0000, -0.3728, 0.7500, -0.6096, -0.0782, -0.0467, 0.5787, 0.7230,
-0.4480, -0.9231, 0.8529, -0.9942, -0.9883, 0.9108, 0.2968, -0.4507,
1.0000, 0.6721, 0.2267, 0.4615, 0.9947, 0.1456, 0.8221, 0.9437,
0.9908, -0.3572, 0.7612, 0.9592, -0.9562, -0.4870, -0.7562, 0.0359,
-0.9192, -0.1009, -0.9800, 0.9813, 0.9810, 0.5379, 0.3425, 0.6109,
1.0000, -0.3027, 0.7574, -0.7509, 0.9638, -0.9998, -0.9582, -0.5514,
-0.2216, -0.8854, -0.4296, 0.4852, -0.9855, 0.9070, 0.7439, -0.9990,
-0.9959, -0.2874, 0.9608, 0.2318, -0.9910, -0.8865, -0.6975, 0.7487,
-0.4122, -0.9673, -0.3669, -0.4526, 0.6604, -0.3187, 0.7298, 0.9125,
0.4807, -0.6056, -0.4237, -0.1390, -0.9107, 0.9518, -0.9408, -0.9633,
-0.3652, 1.0000, -0.6232, 0.9462, 0.8387, 0.8843, -0.2297, 0.3219,
0.9521, 0.3487, -0.8649, -0.9282, -0.9284, -0.5327, 0.8448, 0.3688,
0.7981, 0.8851, 0.7787, 0.2426, -0.1275, 0.0627, 1.0000, -0.3206,
-0.2520, -0.6673, -0.1260, -0.5247, -0.7049, 1.0000, 0.4109, 0.5311,
-0.9940, -0.9166, -0.9730, 1.0000, 0.8794, -0.9183, 0.8293, 0.7381,
-0.2460, 0.9386, -0.2837, -0.4118, 0.3727, 0.2090, 0.9740, -0.6829,
-0.9815, -0.7682, 0.5797, -0.9837, 1.0000, -0.7271, -0.3817, -0.5546,
-0.3283, 0.7684, -0.0143, -0.9911, -0.3802, 0.3628, 0.9782, 0.4206,
-0.8047, -0.9602, 0.8745, 0.8764, -0.9573, -0.9632, 0.9773, -0.9949,
0.7125, 1.0000, 0.4571, 0.2506, 0.3976, -0.6670, 0.5187, -0.3151,
0.8834, -0.9821, -0.3898, -0.2883, 0.4084, -0.2781, -0.2061, 0.8508,
0.2817, -0.7630, -0.7728, -0.2177, 0.5364, 0.9567, -0.3668, -0.3175,
0.2292, -0.2774, -0.9759, -0.3111, -0.6086, -1.0000, 0.8677, -1.0000,
0.6034, 0.4101, -0.3864, 0.9284, 0.1872, 0.6729, -0.8836, -0.9023,
0.3494, 0.8616, -0.4054, -0.7423, -0.8551, 0.4803, -0.1593, 0.2331,
-0.6162, 0.8185, -0.3419, 1.0000, 0.2587, -0.8902, -0.9961, 0.2707,
-0.3756, 1.0000, -0.9811, -0.9696, 0.5954, -0.8535, -0.9130, 0.4765,
0.1160, -0.8599, -0.9679, 0.9881, 0.9737, -0.7367, 0.5846, -0.4798,
-0.7181, 0.1214, 0.8910, 0.9905, 0.5837, 0.9690, 0.6102, -0.1811,
0.9854, 0.3214, 0.7951, 0.2239, 1.0000, 0.4570, -0.9499, 0.0648,
-0.9921, -0.3552, -0.9790, 0.4716, 0.3796, 0.9339, -0.2584, 0.9850,
-0.8884, 0.1191, -0.8238, -0.5894, 0.5726, -0.9569, -0.9890, -0.9905,
0.6442, -0.5990, -0.1902, 0.2571, 0.2937, 0.5793, 0.5845, -1.0000,
0.9707, 0.6047, 0.9628, 0.9769, 0.7969, 0.5512, 0.3688, -0.9935,
-0.9973, -0.4643, -0.3199, 0.8706, 0.7476, 0.9571, 0.5946, -0.5784,
-0.1496, -0.5475, -0.3116, -0.9961, 0.6147, -0.7207, -0.9945, 0.9781,
-0.1213, -0.2692, -0.1053, -0.8538, 0.9879, 0.9060, 0.6091, 0.0540,
0.5773, 0.9452, 0.9863, 0.9897, -0.8878, 0.9530, -0.7677, 0.6045,
0.6884, -0.9590, 0.2035, 0.6452, -0.6606, 0.4414, -0.3820, -0.9940,
0.7435, -0.4090, 0.7173, -0.5816, -0.0113, -0.5246, -0.1963, -0.8200,
-0.8477, 0.7834, 0.7695, 0.9511, 0.7920, -0.0710, -0.8780, -0.3213,
-0.8836, -0.9467, 0.9839, -0.1764, -0.5626, 0.6764, 0.0374, 0.7217,
0.2637, -0.4755, -0.4506, -0.8280, 0.9425, -0.5586, -0.6535, -0.7009,
0.8099, 0.4119, 1.0000, -0.8556, -0.9608, -0.5079, -0.4905, 0.5642,
-0.7158, -1.0000, 0.5819, -0.5215, 0.7922, -0.8036, 0.8845, -0.8239,
-0.9957, -0.3756, 0.5139, 0.7862, -0.5883, -0.8074, 0.7892, -0.4830,
0.9881, 0.9518, -0.7551, 0.0959, 0.8097, -0.7260, -0.8043, 0.9666]],
grad_fn=), hidden_states=(tensor([[[ 0.1686, -0.2858, -0.3261, ..., -0.0276, 0.0383, 0.1640],
[-0.0720, 0.9108, -1.1910, ..., 0.7400, 0.8613, 0.2174],
[-0.2330, 0.1164, 0.5087, ..., 0.3019, 0.1804, 0.3744],
[-0.4960, 0.2760, 0.4670, ..., 0.7051, 0.3014, 0.3416],
[ 0.3939, -0.2877, -0.3005, ..., 0.4445, 0.5525, 0.6004],
[-0.3251, -0.3188, -0.1163, ..., -0.3960, 0.4112, -0.0776]]],
grad_fn=), tensor([[[ 0.0388, 0.1101, -0.1576, ..., 0.1122, 0.0586, -0.0161],
[ 0.2020, 1.4445, -1.2794, ..., 0.1721, 0.5129, 0.4866],
[ 0.4994, 0.3372, 0.6230, ..., -0.4619, 0.3822, -0.0990],
[-0.3165, 0.2699, 0.5134, ..., 0.4806, 0.0789, 0.2612],
[ 0.7581, -0.0802, -0.4836, ..., 0.1307, 0.6724, 0.8874],
[-0.0213, -0.2086, -0.1192, ..., -0.5538, 0.3962, 0.0161]]],
grad_fn=), tensor([[[-0.0946, -0.2108, -0.3258, ..., 0.2693, 0.1359, 0.0787],
[ 0.2205, 1.9918, -0.5364, ..., 0.0188, 0.4090, 0.2235],
[ 0.0895, 0.3188, 1.3691, ..., -0.5914, 0.3385, -0.6825],
[-0.4056, -0.5321, 0.1526, ..., 0.2154, 0.0022, 0.1891],
[ 0.4781, -0.3290, -0.5145, ..., -0.1450, 0.4181, 0.9846],
[-0.1271, -0.2587, 0.1175, ..., -0.1840, 0.3792, -0.1247]]],
grad_fn=), tensor([[[-0.0772, -0.3340, -0.0871, ..., 0.3118, 0.1349, 0.2388],
[-0.1580, 1.6831, 0.1133, ..., 0.3931, 0.2327, 0.0816],
[-0.1316, -0.1502, 1.4529, ..., -0.6485, 0.1748, -0.7754],
[-0.2039, -0.8913, 0.3630, ..., 0.2297, -0.0181, 0.0804],
[ 0.4439, -0.6840, -0.1635, ..., 0.2165, 0.0798, 0.3781],
[-0.0752, -0.1274, 0.1253, ..., 0.0260, 0.0710, -0.0180]]],
grad_fn=), tensor([[[ 0.0306, -0.5660, -0.7076, ..., 0.5205, 0.0904, 0.5262],
[-0.0427, 1.2388, 0.6547, ..., 0.4277, -0.0733, -0.5716],
[-0.3474, -0.7610, 1.7970, ..., -0.6507, 0.1996, -0.3475],
[-0.1496, -1.8310, 0.5126, ..., 0.0577, 0.3315, -0.2452],
[ 0.9380, -1.1031, -0.5902, ..., 0.3442, 0.1777, 0.4778],
[-0.0280, -0.0590, 0.0084, ..., 0.0154, 0.0535, -0.0352]]],
grad_fn=), tensor([[[ 0.0812, -0.2125, -0.6219, ..., -0.0835, 0.3285, 0.4436],
[ 0.2676, 0.9164, -0.2069, ..., 0.1127, -0.2140, -0.3537],
[-0.5943, -1.1156, 0.8255, ..., -0.8022, 0.1315, 0.1764],
[ 0.1365, -1.9111, 0.4468, ..., 0.0412, -0.1455, -0.0240],
[ 0.5411, -0.9937, -0.8214, ..., 0.0132, 0.0558, 0.8175],
[-0.0152, -0.0440, 0.0131, ..., 0.0220, 0.0078, -0.0387]]],
grad_fn=), tensor([[[ 1.6313e-01, -7.7907e-01, -7.2191e-01, ..., -1.3117e-01,
3.6851e-01, 3.8299e-01],
[ 3.6151e-01, 1.0496e+00, 1.2720e-03, ..., -4.6335e-01,
-6.0164e-01, -4.5743e-01],
[-2.2870e-01, -1.1642e+00, 4.7943e-01, ..., -1.5846e+00,
7.5402e-01, 1.0892e-01],
[-2.2471e-01, -1.9071e+00, 4.3491e-01, ..., -5.4294e-01,
3.2445e-01, -4.0690e-01],
[-1.4225e-02, -1.0958e+00, -1.2180e+00, ..., -3.4226e-01,
-2.5873e-01, 5.2984e-01],
[ 1.1444e-02, -4.0294e-02, -1.7600e-02, ..., 6.7591e-03,
-2.1420e-02, -4.3917e-02]]], grad_fn=), tensor([[[ 0.3211, -0.2976, -0.7657, ..., -0.5251, 0.4389, 0.4900],
[ 0.1868, 1.0571, 0.6403, ..., -0.6938, -0.5096, -0.2832],
[-0.7042, -1.1779, 0.5186, ..., -1.3521, 0.4648, -0.3709],
[-0.8978, -1.9789, 0.5508, ..., -0.6249, 0.6898, -0.6074],
[-0.2716, -1.5542, -2.0756, ..., -0.7194, -0.1692, 0.6462],
[-0.0143, -0.0433, -0.0149, ..., -0.0209, 0.0118, -0.0601]]],
grad_fn=), tensor([[[ 0.1281, -0.2184, -1.3155, ..., -0.9203, 0.2414, 0.8508],
[ 0.2685, 0.4445, 0.6281, ..., -0.5510, -0.7145, -0.2859],
[-0.3381, -1.4646, 0.3618, ..., -1.2901, 0.0663, -0.4174],
[-0.4894, -1.8939, 0.5440, ..., -0.4871, 0.8371, -0.3794],
[-0.6856, -1.2587, -1.8686, ..., -0.8897, -0.5912, 0.4858],
[ 0.0081, -0.0234, 0.0395, ..., -0.0532, -0.0211, -0.0851]]],
grad_fn=), tensor([[[-0.0527, -0.0867, -1.7301, ..., -0.4690, 0.0828, 0.3822],
[-0.0247, 0.3531, 0.7252, ..., -0.4155, -0.9026, -0.6999],
[-0.0070, -1.0945, 0.3933, ..., -0.7901, -0.3328, -0.5838],
[-0.2029, -1.6412, 0.6373, ..., -0.2943, 0.3732, -0.5496],
[-1.1925, -1.0985, -1.7700, ..., -0.8319, -0.1459, -0.6030],
[ 0.0325, -0.0043, 0.0355, ..., -0.0543, -0.0607, -0.0377]]],
grad_fn=), tensor([[[ 0.0916, -0.1851, -0.9353, ..., -0.8346, -0.5341, 0.3569],
[ 0.1064, -0.8527, 0.4448, ..., -0.3772, -1.0587, -1.1242],
[ 0.2348, -1.5743, 0.1723, ..., -0.5256, -0.7394, -0.7555],
[-0.0975, -1.7546, 0.6622, ..., -0.0796, 0.1267, -0.7318],
[-0.6475, -1.0668, -1.2231, ..., -0.5111, 0.0027, -0.2800],
[ 0.0115, -0.0389, 0.0561, ..., 0.0909, -0.0160, -0.0099]]],
grad_fn=), tensor([[[ 1.9646e-01, -2.4143e-01, -4.1810e-01, ..., -7.7052e-01,
1.2252e-01, -4.2691e-03],
[-3.0136e-01, -1.5718e-01, 1.8951e-01, ..., -9.8252e-01,
-1.9064e-01, -1.2972e+00],
[-1.4842e-01, -6.8893e-01, 3.6466e-01, ..., -1.1063e+00,
-2.7100e-02, -9.8203e-01],
[-3.7955e-01, -1.2046e+00, 7.2221e-01, ..., -7.6463e-01,
6.0633e-01, -7.2070e-01],
[-3.6614e-01, -1.2337e+00, -9.5419e-01, ..., -4.5989e-01,
3.6219e-01, -6.3111e-02],
[ 4.0417e-02, 3.7746e-04, -3.1245e-02, ..., 1.5826e-02,
-2.6091e-02, 2.6609e-02]]], grad_fn=), tensor([[[-0.0573, 0.1625, -0.4028, ..., -0.4454, 0.2880, 0.3058],
[-0.1744, -0.4683, -0.2665, ..., -0.2229, 0.7547, -0.4456],
[ 0.3816, -0.6513, 0.3829, ..., -0.7145, 0.1510, -0.7006],
[-0.1328, -0.6548, 0.9204, ..., -0.4545, 0.3768, -1.0013],
[-0.0467, -0.7707, -0.9214, ..., 0.1693, 0.3671, -0.2546],
[ 0.6713, -0.0606, -0.2756, ..., 0.1626, -0.3610, -0.2107]]],
grad_fn=)), past_key_values=None, attentions=None, cross_attentions=None)