RNN (Recurrent Neural Network): a network that generally takes sequence data as input, captures the relational features between elements of the sequence through its internal structure, and usually produces output in sequence form as well.
Single-layer RNN network structure:
The recurrence mechanism of an RNN lets the hidden-layer result produced at the previous time step serve as part of the input at the current time step.
Because the RNN structure makes good use of the relationships within a sequence, it handles naturally continuous inputs such as human language and speech well, and is widely used across NLP tasks such as text classification, sentiment analysis, intent recognition, and machine translation.
Suppose the user inputs "What time is it?". Here is how an RNN processes it:
The current input x(t) and the previous hidden state h(t-1) are concatenated into [x(t), h(t-1)]; this tensor then passes through a fully connected (linear) layer with tanh as the activation function, producing this time step's output h(t), which in turn enters the cell together with x(t+1) as part of the next time step's input:
h_t = \tanh(W_t[x_t, h_{t-1}] + b_t)
The tanh activation: it helps regulate the values flowing through the network, squashing them into the range between -1 and 1.
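To make the formula concrete, here is a minimal sketch of one RNN time step written by hand (the sizes are illustrative assumptions, chosen to match the nn.RNN example below):
import torch

input_size, hidden_size = 10, 20
W = torch.randn(hidden_size, input_size + hidden_size)  # W_t, acting on [x_t, h_{t-1}]
b = torch.randn(hidden_size)                            # b_t

x_t = torch.randn(input_size)      # input at the current time step
h_prev = torch.zeros(hidden_size)  # h(t-1), the previous hidden state

# h_t = tanh(W_t [x_t, h_{t-1}] + b_t)
h_t = torch.tanh(W @ torch.cat([x_t, h_prev]) + b)
print(h_t.shape)  # torch.Size([20]); every entry lies in (-1, 1) because of tanh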
The RNN class lives in torch.nn.RNN. Its main initialization parameters:
- input_size: size of the feature dimension of the input tensor x
- hidden_size: size of the feature dimension of the hidden tensor h
- num_layers: number of stacked recurrent layers
- nonlinearity: the non-linearity to use, 'tanh' or 'relu'. Default: 'tanh'
Main parameters when calling the instantiated RNN object:
- input: the input tensor x, of shape (seq_len, batch, input_size)
- h0: the initial hidden state, of shape (num_layers, batch, hidden_size)
import torch
import torch.nn as nn

# input_size=10, hidden_size=20, num_layers=2
rnn = nn.RNN(10, 20, 2)
# input of shape (seq_len=5, batch=3, input_size=10)
input = torch.randn(5, 3, 10)
# initial hidden state of shape (num_layers=2, batch=3, hidden_size=20)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output)
print(output.shape)
print(hn)
print(hn.shape)
output:
tensor([[[-0.1823, 0.0859, -0.3405, -0.5000, -0.6306, 0.5065, 0.8202,
-0.2139, 0.5886, -0.1549, 0.3035, 0.3669, -0.3702, -0.0026,
0.0604, 0.1055, -0.0163, -0.2904, 0.2216, 0.0020],
[ 0.2791, -0.1390, 0.3652, 0.0539, -0.5179, 0.7433, -0.0418,
0.8043, 0.5498, -0.0131, 0.4987, 0.8964, 0.0033, 0.1708,
-0.0594, -0.0106, -0.5742, 0.5557, -0.3524, -0.3199],
[-0.0448, 0.2398, -0.1254, 0.5049, 0.6504, 0.6963, 0.5681,
0.5640, 0.2442, -0.6644, 0.2833, 0.7397, 0.0966, 0.4050,
0.5397, 0.1153, -0.5372, -0.4970, 0.0586, -0.1714]],
[[ 0.5013, 0.1563, 0.1514, 0.1719, -0.3103, -0.4294, -0.6875,
0.0665, -0.4604, 0.1708, 0.1925, 0.0077, -0.2452, -0.1904,
0.4462, 0.0012, -0.2967, -0.5996, -0.0416, 0.0766],
[ 0.6177, 0.4556, 0.3853, 0.2834, 0.3121, -0.1427, -0.4408,
-0.1028, -0.7400, 0.2298, -0.5990, 0.4145, 0.1973, 0.1061,
0.3418, -0.1150, -0.2209, -0.5048, 0.0269, 0.4954],
[ 0.1053, 0.3156, 0.2890, 0.2079, 0.0477, 0.2353, -0.0389,
-0.0014, -0.2171, -0.0972, 0.0658, 0.4972, 0.2478, 0.0355,
0.4458, 0.0405, -0.5211, -0.2562, -0.1064, 0.5259]],
[[-0.4234, 0.1803, 0.1560, 0.4580, -0.2345, -0.2388, 0.3107,
-0.0058, 0.0634, 0.0977, -0.4543, 0.0582, -0.0860, 0.2199,
0.1864, -0.5531, 0.5284, -0.2800, -0.0510, 0.0912],
[-0.5156, 0.4014, 0.0628, 0.3032, -0.0117, -0.1661, 0.5899,
0.1559, 0.2996, -0.4454, 0.0348, 0.0651, -0.5742, 0.2271,
0.1080, -0.3659, 0.6118, -0.4189, 0.0549, -0.0393],
[-0.0868, 0.5991, 0.1813, 0.5599, 0.3917, -0.3454, 0.0961,
0.0566, 0.0284, -0.3377, 0.0170, -0.1184, -0.5352, 0.3805,
0.1599, -0.1647, 0.2100, -0.5550, 0.1266, 0.2302]],
[[-0.3534, 0.1374, 0.1209, 0.0387, -0.1049, -0.2417, 0.1742,
-0.2224, 0.5119, -0.6369, 0.3746, 0.4883, -0.1907, 0.3288,
0.1200, 0.0569, 0.0759, -0.1567, 0.3188, 0.2419],
[ 0.0486, 0.3413, 0.1351, 0.5912, -0.3284, -0.3300, -0.0787,
0.1665, 0.1738, -0.2786, 0.3029, 0.0880, 0.3581, -0.0811,
0.4021, -0.3304, -0.2823, -0.2832, 0.1019, 0.3242],
[-0.1608, 0.1246, 0.0863, 0.3260, -0.2099, 0.2095, 0.4521,
0.4346, 0.3898, -0.4924, 0.2472, 0.2306, 0.3713, -0.0955,
0.3075, -0.2875, -0.1641, -0.2343, -0.1563, 0.3321]],
[[ 0.0825, 0.1601, 0.3947, 0.4077, 0.2677, -0.4913, -0.3839,
-0.1458, -0.0360, 0.2327, 0.0738, 0.2608, 0.3539, -0.0649,
0.2555, -0.2991, -0.2076, -0.2232, -0.2227, 0.1667],
[ 0.1509, 0.4328, 0.1422, 0.5717, -0.1864, -0.0670, 0.1437,
0.2307, -0.2267, 0.0637, -0.2936, 0.2159, 0.4971, 0.0890,
0.2068, -0.5263, 0.3438, -0.2558, 0.1983, 0.4059],
[ 0.2452, 0.6272, 0.1001, 0.2857, 0.2405, -0.3052, -0.0981,
0.1137, -0.1352, -0.3110, 0.1155, 0.0432, -0.0898, -0.0357,
0.1601, -0.0056, -0.1636, -0.4776, 0.1329, 0.3925]]],
grad_fn=<StackBackward>)
torch.Size([5, 3, 20])
hn:
tensor([[[ 0.4038, 0.0203, 0.0424, 0.3863, -0.0233, -0.4767, -0.2328,
0.6005, -0.1970, -0.1546, 0.1492, -0.4493, -0.8081, 0.7578,
-0.6790, 0.1135, -0.1818, -0.0088, -0.2488, -0.1144],
[-0.1782, -0.4026, 0.4985, -0.5878, 0.4833, 0.5260, 0.0710,
0.5741, 0.3153, -0.0460, 0.1516, 0.3593, -0.7491, 0.4448,
-0.6297, 0.2588, 0.4649, -0.2219, -0.5977, 0.3895],
[-0.1429, -0.0431, 0.7642, -0.3143, 0.4679, 0.1455, 0.3204,
-0.2070, -0.1016, -0.4045, -0.3219, -0.2693, -0.6370, -0.0010,
-0.5872, -0.5141, 0.0144, -0.4947, 0.5004, -0.5219]],
[[ 0.0825, 0.1601, 0.3947, 0.4077, 0.2677, -0.4913, -0.3839,
-0.1458, -0.0360, 0.2327, 0.0738, 0.2608, 0.3539, -0.0649,
0.2555, -0.2991, -0.2076, -0.2232, -0.2227, 0.1667],
[ 0.1509, 0.4328, 0.1422, 0.5717, -0.1864, -0.0670, 0.1437,
0.2307, -0.2267, 0.0637, -0.2936, 0.2159, 0.4971, 0.0890,
0.2068, -0.5263, 0.3438, -0.2558, 0.1983, 0.4059],
[ 0.2452, 0.6272, 0.1001, 0.2857, 0.2405, -0.3052, -0.0981,
0.1137, -0.1352, -0.3110, 0.1155, 0.0432, -0.0898, -0.0357,
0.1601, -0.0056, -0.1636, -0.4776, 0.1329, 0.3925]]],
grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
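Continuing the example, a quick sanity check on how output and hn relate: for a unidirectional RNN, hn[-1] is the top layer's hidden state at the final time step, which is exactly the last time-step slice of output.
# output[-1] is the top layer's hidden state at the last time step,
# hn[-1] is the final hidden state of the last layer: they are the same values
print(torch.allclose(output[-1], hn[-1]))  # True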
Advantages: the internal structure is simple and computationally cheap; compared with variants such as LSTM and GRU it has far fewer parameters, and it performs well on short-sequence tasks.
Disadvantages: in practice the classic RNN performs poorly at relating distant positions in long sequences, because overly long sequences make the gradient computation during backpropagation unstable, producing vanishing or exploding gradients.
What are vanishing and exploding gradients?
According to the backpropagation algorithm and the chain rule, the gradient computation can be simplified to the following formula:
D_n = \sigma'(z_1)w_1 \cdot \sigma'(z_2)w_2 \cdot \ldots \cdot \sigma'(z_n)w_n
Here the derivative of sigmoid is bounded within [0, 0.25]. If the weights w in the formula are also smaller than 1, then after such a chain of multiplications the final gradient becomes vanishingly small; this is called gradient vanishing. Conversely, if w is artificially made larger than 1, the final gradient may grow excessively large; this is called gradient explosion.
If gradients vanish, the weights cannot be updated and training ultimately fails. Exploding gradients cause drastically large parameter updates, and in extreme cases the results overflow (NaN).
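A tiny numeric sketch of the product above (the values 0.9 and 5.0 and the depth 20 are illustrative assumptions): even at sigmoid's maximum derivative of 0.25, the chained product shrinks geometrically once |w| < 1, and blows up once w is pushed well past 4.
# sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) peaks at 0.25 (at z = 0)
sigmoid_grad_max = 0.25

for w, label in [(0.9, "vanishing"), (5.0, "exploding")]:
    grad = 1.0
    for _ in range(20):  # 20 chained factors sigma'(z_i) * w_i
        grad *= sigmoid_grad_max * w
    print(label, grad)   # vanishing: ~1e-13, exploding: ~87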
LSTM (Long Short-Term Memory), a variant designed to alleviate these long-sequence problems, is available as torch.nn.LSTM. Its main initialization parameters:
- input_size: size of the feature dimension of the input tensor x
- hidden_size: size of the feature dimension of the hidden tensor h
- num_layers: number of stacked recurrent layers
- bidirectional: if True, becomes a bidirectional LSTM. Default: False (not bidirectional)
Main parameters when calling the instantiated object:
- input: the input tensor x
- h0: the initial hidden state tensor
- c0: the initial cell state tensor
Code example and output:
import torch.nn as nn
import torch

# input_size=10, hidden_size=20, num_layers=2
rnn = nn.LSTM(10, 20, 2)
# input of shape (seq_len=5, batch=3, input_size=10)
input = torch.randn(5, 3, 10)
# initial hidden and cell states: (num_layers=2, batch=3, hidden_size=20)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
print(output)
print(output.shape)
print(hn)
print(hn.shape)
print(cn)
print(cn.shape)
tensor([[[ 6.4420e-02, 6.2410e-02, -3.1918e-01, -1.7109e-01, -8.8068e-02,
2.2602e-01, -3.2335e-01, -7.1837e-03, 8.5058e-02, 5.9582e-03,
-2.9552e-01, -1.1200e-01, -9.9346e-02, -6.7507e-02, 3.3893e-01,
2.0244e-01, -3.9047e-01, -5.2756e-02, 1.4811e-01, 1.9268e-01],
[ 4.9299e-02, -8.9591e-02, 1.9591e-01, 1.7071e-01, 4.0422e-01,
-3.1363e-01, -2.9693e-01, 1.2103e-01, 1.0956e-02, 7.6901e-02,
2.3720e-01, 6.3182e-03, -3.2968e-01, -1.7734e-01, 1.4216e-01,
-2.0142e-01, -3.4723e-01, 8.3657e-02, 1.3796e-01, -2.2886e-01],
[-2.0725e-01, 2.0795e-01, 4.6395e-02, 1.3417e-02, -8.3942e-02,
8.8684e-02, -1.4023e-01, -4.5434e-02, 3.4325e-01, 5.9968e-02,
-1.6269e-01, -5.3913e-03, 3.2542e-02, -5.8951e-01, -1.8863e-01,
-2.8661e-03, -2.8178e-01, -2.7340e-01, 3.1375e-01, -1.5382e-01]],
[[-1.6560e-02, 2.0795e-02, -1.2222e-01, -9.0507e-02, 1.0746e-03,
1.1501e-01, -2.1268e-01, -2.4595e-02, 8.3267e-02, 6.6421e-02,
-9.1167e-02, -6.3106e-02, -2.9599e-02, -1.0097e-01, 1.0818e-01,
1.6963e-01, -2.5700e-01, -6.5596e-02, 4.6941e-02, 2.0623e-01],
[ 4.0872e-02, 4.0869e-03, 2.3773e-02, 1.2502e-01, 2.7042e-01,
-8.8856e-02, -1.0137e-01, 1.1388e-02, 5.1466e-02, 6.7936e-02,
1.7766e-01, 1.6299e-02, -1.0510e-01, -1.6442e-01, 1.3268e-01,
-1.3334e-02, -2.8878e-01, 4.8023e-02, -3.1834e-03, -1.1887e-01],
[-1.0880e-01, 1.4484e-01, 1.1940e-02, -7.1890e-02, -6.4775e-03,
1.3693e-01, -1.0062e-01, -5.6859e-02, 2.5171e-01, 2.6230e-02,
-1.9837e-02, 1.8749e-02, 2.6209e-02, -3.7243e-01, -1.3988e-01,
3.8677e-02, -1.3416e-01, -2.4021e-01, 8.4407e-02, -1.2588e-01]],
[[-3.1532e-02, 2.4761e-03, -5.6999e-02, -8.2566e-02, 1.0148e-01,
4.1854e-02, -1.3071e-01, -2.1285e-02, 1.2221e-01, 6.5598e-02,
2.8480e-02, -2.0582e-02, 8.1437e-04, -1.1505e-01, 3.2730e-02,
1.2462e-01, -1.3487e-01, -5.3599e-02, -3.4288e-02, 1.3610e-01],
[ 3.3348e-03, 3.9802e-02, -2.8489e-02, 2.0086e-02, 2.1576e-01,
-6.5055e-02, -7.2176e-02, -2.6473e-02, 8.7836e-02, 1.1661e-02,
1.6746e-01, 3.1245e-02, -1.7550e-02, -1.5325e-01, 1.1242e-01,
4.9840e-02, -1.9759e-01, 2.8112e-02, -8.6413e-02, -1.6962e-02],
[-7.4427e-02, 1.1710e-01, -6.9011e-03, -1.0997e-01, 5.7745e-02,
8.5939e-02, -5.8605e-02, -2.1025e-02, 2.3754e-01, -4.8881e-03,
4.5520e-02, 2.7855e-02, 1.5404e-02, -2.2167e-01, -8.8727e-02,
6.0912e-02, -1.5258e-02, -1.6793e-01, -5.7971e-02, -4.9538e-02]],
[[-2.6533e-02, 1.0154e-03, -2.2220e-02, -1.0029e-01, 1.5870e-01,
-3.3663e-02, -9.2402e-02, -7.9979e-03, 1.1908e-01, 5.8657e-02,
9.9747e-02, 1.0687e-02, 1.8266e-02, -1.2505e-01, 1.2888e-02,
9.5571e-02, -6.4849e-02, -5.3283e-02, -1.0791e-01, 1.0686e-01],
[-1.7761e-03, 7.6669e-02, -3.7589e-02, -2.1660e-02, 1.9882e-01,
-4.7095e-02, -6.7763e-02, -3.8129e-02, 1.0340e-01, -2.2646e-02,
1.5637e-01, 5.3467e-02, 3.2359e-02, -1.4168e-01, 9.8835e-02,
6.7367e-02, -1.5017e-01, 9.0694e-03, -1.5094e-01, 4.0554e-02],
[-7.0019e-02, 9.5896e-02, -2.4826e-02, -1.2313e-01, 1.3319e-01,
1.6771e-02, -5.4803e-02, -1.4518e-03, 2.0063e-01, -2.8075e-02,
9.3925e-02, 4.1257e-02, 3.0619e-02, -1.8029e-01, -2.7524e-02,
7.9259e-02, 1.4481e-02, -1.2273e-01, -1.2385e-01, -7.2297e-03]],
[[-8.5072e-03, 3.0404e-02, -1.3843e-02, -1.0111e-01, 1.8537e-01,
-6.3342e-02, -7.6308e-02, -1.4481e-03, 1.0062e-01, 2.3903e-02,
1.2981e-01, 3.1212e-02, 3.4981e-02, -1.3894e-01, 2.9375e-02,
8.0636e-02, -3.9740e-02, -4.5989e-02, -1.5719e-01, 9.7053e-02],
[-8.6290e-03, 7.1298e-02, -3.4861e-02, -5.4390e-02, 1.9481e-01,
-4.2411e-02, -5.5742e-02, -2.8054e-02, 1.2071e-01, -3.2352e-02,
1.5135e-01, 5.1046e-02, 6.2415e-02, -1.3878e-01, 8.7795e-02,
8.3503e-02, -9.5001e-02, 4.8436e-04, -1.8418e-01, 5.5274e-02],
[-4.5381e-02, 8.7877e-02, -2.7749e-02, -1.1548e-01, 1.6753e-01,
-2.6462e-02, -6.3796e-02, 1.3007e-02, 1.9289e-01, -4.2623e-02,
1.0475e-01, 5.0453e-02, 2.0430e-02, -1.5570e-01, -3.6697e-03,
6.9533e-02, 3.9411e-02, -6.6003e-02, -1.5108e-01, 3.7231e-03]]],
grad_fn=<StackBackward>)
torch.Size([5, 3, 20])
tensor([[[ 0.1539, 0.1394, 0.0428, -0.0649, -0.0027, -0.0192, 0.0837,
-0.0667, 0.1427, 0.0111, -0.2659, 0.2150, 0.1839, 0.0269,
-0.0279, -0.0288, 0.1615, -0.2518, -0.1585, -0.1057],
[ 0.0517, 0.0547, -0.1118, -0.0778, 0.0819, -0.0255, -0.0973,
-0.1697, 0.1611, -0.0012, -0.1939, -0.1689, 0.1731, 0.0815,
-0.0412, -0.0041, 0.2454, -0.1557, -0.0754, -0.0730],
[ 0.0637, -0.1256, 0.2050, 0.0290, -0.1689, 0.2387, 0.2580,
-0.0188, 0.0372, 0.0115, -0.0500, 0.1425, 0.1629, 0.0287,
0.0690, -0.0600, 0.2022, -0.2592, -0.0233, -0.1594]],
[[-0.0085, 0.0304, -0.0138, -0.1011, 0.1854, -0.0633, -0.0763,
-0.0014, 0.1006, 0.0239, 0.1298, 0.0312, 0.0350, -0.1389,
0.0294, 0.0806, -0.0397, -0.0460, -0.1572, 0.0971],
[-0.0086, 0.0713, -0.0349, -0.0544, 0.1948, -0.0424, -0.0557,
-0.0281, 0.1207, -0.0324, 0.1514, 0.0510, 0.0624, -0.1388,
0.0878, 0.0835, -0.0950, 0.0005, -0.1842, 0.0553],
[-0.0454, 0.0879, -0.0277, -0.1155, 0.1675, -0.0265, -0.0638,
0.0130, 0.1929, -0.0426, 0.1048, 0.0505, 0.0204, -0.1557,
-0.0037, 0.0695, 0.0394, -0.0660, -0.1511, 0.0037]]],
grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
tensor([[[ 0.3522, 0.2659, 0.0630, -0.1412, -0.0063, -0.0369, 0.1510,
-0.1600, 0.3845, 0.0327, -0.5553, 0.3588, 0.4086, 0.0488,
-0.0647, -0.0562, 0.3051, -0.4392, -0.3177, -0.1919],
[ 0.1868, 0.1283, -0.2144, -0.1914, 0.1720, -0.0421, -0.1722,
-0.3835, 0.3128, -0.0024, -0.3928, -0.3446, 0.2782, 0.1226,
-0.1138, -0.0118, 0.4258, -0.3516, -0.1506, -0.1498],
[ 0.1064, -0.2042, 0.4189, 0.0742, -0.3043, 0.4558, 0.5030,
-0.0542, 0.0815, 0.0236, -0.1197, 0.3385, 0.3621, 0.0485,
0.1882, -0.1620, 0.5038, -0.4896, -0.0457, -0.2475]],
[[-0.0162, 0.0641, -0.0293, -0.2272, 0.4017, -0.1098, -0.1557,
-0.0026, 0.2079, 0.0456, 0.2684, 0.0647, 0.0686, -0.2643,
0.0536, 0.1555, -0.0743, -0.0893, -0.2778, 0.1935],
[-0.0166, 0.1548, -0.0735, -0.1159, 0.4260, -0.0708, -0.1082,
-0.0534, 0.2580, -0.0575, 0.3164, 0.1054, 0.1184, -0.2603,
0.1653, 0.1670, -0.1896, 0.0009, -0.3372, 0.1154],
[-0.0863, 0.1792, -0.0580, -0.2601, 0.3489, -0.0469, -0.1318,
0.0254, 0.4149, -0.0812, 0.2163, 0.1072, 0.0416, -0.2861,
-0.0067, 0.1414, 0.0742, -0.1296, -0.2704, 0.0071]]],
grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
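For a single time step, torch.nn.LSTMCell exposes the same computation without the sequence loop; a minimal sketch (sizes chosen to match the example above):
import torch
import torch.nn as nn

cell = nn.LSTMCell(10, 20)   # input_size=10, hidden_size=20
x_t = torch.randn(3, 10)     # batch of 3 inputs for one time step
h = torch.zeros(3, 20)       # initial hidden state
c = torch.zeros(3, 20)       # initial cell state

h, c = cell(x_t, (h, c))     # one step: new hidden and cell states
print(h.shape, c.shape)      # torch.Size([3, 20]) torch.Size([3, 20])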
Internal structure analysis (GRU): as with the gating analyzed in the LSTM, first compute the update-gate and reset-gate values z(t) and r(t): concatenate x(t) with h(t-1), apply a linear transformation, then a sigmoid activation. The reset-gate value is then applied to h(t-1), controlling how much of the previous time step's information may be used. Next, this reset h(t-1) goes through the basic RNN computation: it is concatenated with x(t), linearly transformed, and passed through tanh to produce the new candidate h(t). Finally, the update-gate value is applied to the new h(t) while 1 minus the gate value is applied to h(t-1); summing the two gives the final hidden-state output h(t). This means the update gate can retain earlier results: as the gate value approaches 1, the output is the new h(t); as it approaches 0, the output is the previous time step's h(t-1).
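A minimal hand-written sketch of the computation just described, following the gate convention used in the text above (note: PyTorch's built-in nn.GRU applies z to h(t-1) instead, i.e. h_t = (1 - z_t) * h_tilde + z_t * h_prev; the weight shapes here are illustrative assumptions):
import torch

input_size, hidden_size = 10, 20
# one weight matrix per gate, each acting on the concatenation [x_t, h_{t-1}]
W_z = torch.randn(hidden_size, input_size + hidden_size)  # update gate
W_r = torch.randn(hidden_size, input_size + hidden_size)  # reset gate
W_h = torch.randn(hidden_size, input_size + hidden_size)  # candidate state

x_t = torch.randn(input_size)
h_prev = torch.randn(hidden_size)

xh = torch.cat([x_t, h_prev])
z_t = torch.sigmoid(W_z @ xh)   # update gate value z(t)
r_t = torch.sigmoid(W_r @ xh)   # reset gate value r(t)
# the reset gate controls how much of h(t-1) enters the candidate state
h_tilde = torch.tanh(W_h @ torch.cat([x_t, r_t * h_prev]))
# the update gate interpolates between the candidate and the old state
h_t = z_t * h_tilde + (1 - z_t) * h_prev
print(h_t.shape)  # torch.Size([20])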
Bi-GRU and Bi-LSTM share the same logic: without changing the internal structure, the model is applied twice in opposite directions, and the results of the two passes are concatenated as the final output.
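A short sketch of the bidirectional case (hyperparameters are illustrative): with bidirectional=True the same GRU runs over the sequence in both directions and the two hidden sequences are concatenated, so the output feature dimension doubles.
import torch
from torch.nn import GRU

bi_rnn = GRU(10, 20, 2, bidirectional=True)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2 * 2, 3, 20)  # num_layers * num_directions = 4
output, hn = bi_rnn(input, h0)
print(output.shape)  # torch.Size([5, 3, 40]): forward and backward concatenated
print(hn.shape)      # torch.Size([4, 3, 20])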
from torch.nn import GRU
import torch

# input_size=10, hidden_size=20, num_layers=2
rnn = GRU(10, 20, 2)
# input of shape (seq_len=5, batch=3, input_size=10)
input = torch.randn(5, 3, 10)
# initial hidden state of shape (num_layers=2, batch=3, hidden_size=20)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output)
print(output.shape)
print(hn)
print(hn.shape)
tensor([[[-6.9459e-01, -2.6020e-01, -8.8132e-01, -1.0471e+00, -1.5542e+00,
5.9995e-01, -1.0064e+00, 6.1060e-01, -8.9793e-01, -4.6518e-01,
3.3329e-01, -2.0619e-01, 4.2522e-01, 9.1905e-01, 2.7251e-01,
4.4926e-01, 8.2407e-01, -6.3048e-01, 5.4013e-01, -1.2370e-01],
[-1.5277e-02, -3.6983e-01, 6.6451e-01, -1.1350e-01, 3.6478e-01,
4.9056e-01, 4.4266e-02, -5.3051e-02, -7.3330e-01, 5.5783e-01,
7.9211e-01, -1.3003e-01, -2.9227e-01, 1.1308e+00, 4.2109e-01,
-3.6290e-02, 1.6238e-01, -5.2465e-01, -7.3086e-01, 2.6433e-01],
[ 3.4612e-01, -3.6740e-02, 3.5100e-01, 3.3513e-01, 5.4247e-01,
2.8768e-02, -3.9264e-01, 9.4860e-02, 5.1338e-01, -8.6671e-01,
-4.4661e-01, -4.3441e-01, -5.2957e-02, 1.0775e-01, -5.0672e-01,
4.0542e-02, -3.8472e-01, 6.4337e-02, 5.0261e-01, -2.8174e-01]],
[[-3.6883e-01, -3.1634e-01, -6.5082e-01, -6.0509e-01, -9.4691e-01,
5.4833e-01, -5.3345e-01, 3.1349e-01, -5.9341e-01, -2.1943e-01,
6.1886e-02, 1.5400e-02, 3.5497e-01, 6.2187e-01, 3.3893e-01,
2.5864e-01, 5.2170e-01, -4.3129e-01, 2.5660e-01, -1.0916e-02],
[-3.7010e-02, -3.9229e-01, 5.1009e-01, -1.6834e-01, 3.6872e-01,
3.4876e-01, 7.4521e-03, -2.2716e-01, -4.9721e-01, 4.6610e-01,
4.9178e-01, -5.5974e-02, -2.4143e-01, 8.2877e-01, 2.8780e-01,
-1.5498e-01, 1.1125e-01, -3.0916e-01, -5.2575e-01, -6.7723e-02],
[ 1.5300e-01, 6.2819e-02, 1.3036e-01, 2.5860e-01, 2.1189e-01,
1.8184e-01, -3.0030e-01, -4.0117e-02, 1.0346e-01, -3.8004e-01,
-1.6947e-01, -2.5304e-01, -1.5507e-01, 2.5656e-01, -2.6630e-01,
1.3186e-01, -1.7718e-01, 3.9756e-02, 5.1584e-02, -7.1031e-02]],
[[-2.0103e-01, -2.1229e-01, -4.5781e-01, -4.6843e-01, -4.5997e-01,
4.2361e-01, -2.4420e-01, 1.3567e-01, -3.6607e-01, -2.7805e-02,
8.2856e-02, 1.0767e-01, 2.0535e-01, 4.9118e-01, 4.4087e-01,
7.6225e-02, 3.2751e-01, -2.9412e-01, 8.1042e-02, -4.2332e-03],
[-3.8872e-02, -3.6614e-01, 4.3101e-01, -1.2236e-01, 3.5154e-01,
3.7314e-01, -7.8891e-02, -3.3604e-01, -4.1319e-01, 3.7666e-01,
3.0248e-01, -1.0486e-01, -2.3059e-01, 6.7035e-01, 1.2262e-01,
-1.1550e-01, 1.0969e-01, -2.6315e-01, -4.6130e-01, -1.6450e-01],
[-3.7442e-02, 1.5761e-02, 2.8101e-01, 5.1829e-02, 6.7561e-02,
3.6606e-01, -2.6635e-01, -2.0923e-01, -1.2706e-01, -1.4952e-01,
-3.7331e-02, -2.3069e-01, -1.4065e-01, 2.6480e-01, -2.1843e-01,
1.2896e-01, -4.8500e-03, -1.1407e-02, -7.6343e-03, -9.7037e-02]],
[[-8.8321e-02, -1.3655e-01, -4.5159e-01, -2.9163e-01, -1.3257e-01,
2.8671e-01, -1.1048e-01, 3.5248e-02, -1.9131e-01, 4.7080e-04,
1.2787e-01, 1.3459e-01, 1.9933e-01, 4.2335e-01, 5.2391e-01,
3.3715e-02, 2.0858e-01, -2.7867e-01, -9.3280e-03, 3.8756e-02],
[ 1.2003e-02, -3.1858e-01, 3.6018e-01, -5.8120e-02, 3.5037e-01,
3.7452e-01, -4.6973e-02, -3.0096e-01, -2.8530e-01, 2.3618e-01,
1.5918e-01, -6.6228e-02, -1.8211e-01, 4.8781e-01, 3.2872e-02,
-1.1477e-02, 1.2382e-01, -3.0377e-01, -3.0077e-01, -8.1247e-02],
[-2.6341e-02, 2.8667e-03, 2.4752e-01, -1.3758e-01, 8.2617e-02,
4.0154e-01, -2.4415e-01, -2.5054e-01, -2.4248e-01, 1.6624e-02,
1.2189e-01, -1.5927e-01, -1.3366e-01, 3.2527e-01, -8.8212e-02,
1.0427e-01, 8.8365e-02, -1.0926e-01, -1.1473e-01, -8.1202e-02]],
[[-5.4868e-02, -8.6880e-02, -3.0471e-01, -2.6974e-01, 4.6742e-02,
1.9245e-01, -4.1078e-02, -7.0137e-02, -1.1890e-01, -8.7587e-03,
2.3334e-01, 1.7913e-01, 1.2651e-01, 3.6180e-01, 4.3351e-01,
3.4156e-02, 1.4778e-01, -3.2327e-01, 7.1391e-03, -3.2855e-02],
[-8.4417e-03, -2.5009e-01, 4.2045e-01, -8.0115e-02, 2.9774e-01,
3.6584e-01, -1.2970e-01, -2.9642e-01, -2.4022e-01, 1.3007e-01,
1.5285e-01, -4.2584e-02, -1.5034e-01, 3.8386e-01, -5.4949e-02,
-1.6722e-03, 6.9493e-02, -2.6682e-01, -2.6152e-01, -7.6451e-02],
[-3.2398e-02, -4.3465e-02, 2.7941e-01, -1.2414e-01, 1.6196e-01,
4.0282e-01, -1.7708e-01, -2.2177e-01, -2.7108e-01, 1.0402e-01,
1.4508e-01, -1.1615e-01, -1.3291e-01, 3.3067e-01, -4.7963e-02,
7.1976e-02, 1.2032e-01, -1.7992e-01, -1.6073e-01, -6.7630e-02]]],
grad_fn=<StackBackward>)
torch.Size([5, 3, 20])
tensor([[[ 0.3411, -0.1189, 0.2968, -0.0787, 0.3222, -0.1785, -0.2443,
0.2951, 0.0322, -0.2610, -0.2563, -0.0145, -0.1345, 0.1442,
0.1701, -0.5196, 0.2953, -0.3805, -0.2769, 0.4867],
[-0.0700, 0.0500, -0.0204, -0.1923, 0.2298, 0.0201, 0.0578,
0.3886, 0.0550, 0.4585, -0.2719, -0.2821, -0.2649, -0.0644,
-0.0072, 0.1250, -0.3758, 0.2249, -0.0802, 0.1650],
[-0.0533, 0.1476, 0.2704, -0.0181, 0.0572, 0.1012, 0.0101,
0.3611, 0.1178, 0.0130, -0.4630, -0.3046, -0.0808, 0.0264,
0.1128, 0.1478, -0.1283, 0.0871, 0.0025, -0.1136]],
[[-0.0549, -0.0869, -0.3047, -0.2697, 0.0467, 0.1924, -0.0411,
-0.0701, -0.1189, -0.0088, 0.2333, 0.1791, 0.1265, 0.3618,
0.4335, 0.0342, 0.1478, -0.3233, 0.0071, -0.0329],
[-0.0084, -0.2501, 0.4205, -0.0801, 0.2977, 0.3658, -0.1297,
-0.2964, -0.2402, 0.1301, 0.1529, -0.0426, -0.1503, 0.3839,
-0.0549, -0.0017, 0.0695, -0.2668, -0.2615, -0.0765],
[-0.0324, -0.0435, 0.2794, -0.1241, 0.1620, 0.4028, -0.1771,
-0.2218, -0.2711, 0.1040, 0.1451, -0.1162, -0.1329, 0.3307,
-0.0480, 0.0720, 0.1203, -0.1799, -0.1607, -0.0676]]],
grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
When we observe something, we can identify it quickly because the brain rapidly focuses attention on its most distinctive parts to reach a judgment, rather than scanning it from start to finish before deciding. The attention mechanism grew out of exactly this idea.
What is an attention computation rule: it takes three specified inputs Q (query), K (key), and V (value), and produces the attention result through a formula.
Common attention computation rules:
Concatenate Q and K along the feature axis, apply a linear transformation, then softmax, and finally do a tensor multiplication of the result with V:
Attention(Q, K, V) = Softmax(Linear([Q, K])) \cdot V
Concatenate Q and K along the feature axis, apply a linear transformation followed by a tanh activation, then sum over the inner dimension, apply softmax, and finally do a tensor multiplication of the result with V:
Attention(Q, K, V) = Softmax(sum(tanh(Linear([Q, K])))) \cdot V
Take the dot product of Q with the transpose of K, divide by a scaling factor \sqrt{d_k}, apply softmax, and finally do a tensor multiplication of the result with V:
Attention(Q, K, V) = Softmax\left(\frac{Q K^T}{\sqrt{d_k}}\right) \cdot V
When the attention weight matrix and V are both 3-D tensors whose first dimension is the batch size, the multiplication is done with bmm, a special batched tensor multiplication.
import torch
input = torch.rand(10, 3, 4)
mat2 = torch.rand(10, 4, 5)
# if argument 1 has shape (b, n, m) and argument 2 has shape (b, m, p),
# the output has shape (b, n, p)
res = torch.bmm(input, mat2)
print(res)
print(res.shape)
tensor([[[1.5089, 0.7716, 0.7321, 1.4452, 0.8569],
[1.3220, 0.5150, 0.9651, 0.9369, 0.7922],
[1.1471, 0.5643, 0.6489, 1.1991, 0.6842]],
[[1.0974, 1.6059, 1.0994, 1.4619, 1.3693],
[1.1033, 1.6071, 0.9072, 1.0177, 1.1216],
[1.0452, 1.3168, 1.0817, 1.0244, 0.8911]],
[[1.5674, 0.9442, 1.7458, 1.3762, 0.8584],
[1.4692, 0.9947, 1.6580, 1.3287, 0.8270],
[0.4599, 0.4040, 0.5983, 0.4107, 0.4764]],
[[1.7093, 1.6265, 1.2188, 0.9478, 1.5052],
[1.4540, 1.2161, 1.5406, 0.6138, 1.4907],
[1.5233, 1.1481, 1.1055, 0.8532, 1.5669]],
[[0.4366, 0.7821, 0.4608, 0.5289, 0.2629],
[0.8268, 1.4866, 0.4945, 1.0979, 0.9051],
[0.9979, 1.3817, 0.6230, 1.1372, 0.9892]],
[[0.7191, 0.9421, 1.4818, 0.4685, 0.6745],
[0.7998, 1.3734, 1.5346, 0.5162, 1.1392],
[0.8207, 1.0578, 1.8102, 0.6123, 1.2731]],
[[0.8121, 0.8512, 1.6110, 1.5534, 1.1897],
[0.1294, 0.3024, 0.3043, 0.3204, 0.3466],
[0.6966, 0.6364, 1.6101, 1.1096, 0.9638]],
[[0.3239, 0.6002, 0.4062, 0.6762, 0.3783],
[1.5538, 2.0881, 1.2786, 2.2376, 1.6499],
[1.3914, 2.2420, 1.3859, 2.2466, 1.5593]],
[[1.2858, 0.7075, 0.7151, 0.9685, 0.4747],
[1.4487, 1.0928, 1.0394, 1.1043, 0.6741],
[2.1023, 1.6332, 1.5035, 1.3093, 1.3948]],
[[1.0718, 0.4500, 1.0227, 0.6682, 0.5366],
[0.7295, 0.6790, 0.5071, 0.6125, 0.5268],
[0.9271, 1.2282, 0.7771, 0.7781, 0.9138]]])
torch.Size([10, 3, 5])
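Putting the third computation rule and bmm together, a minimal sketch of scaled dot-product attention (the shapes are illustrative assumptions):
import math
import torch
import torch.nn.functional as F

batch, seq_q, seq_k, d_k, d_v = 10, 3, 6, 4, 5
Q = torch.rand(batch, seq_q, d_k)
K = torch.rand(batch, seq_k, d_k)
V = torch.rand(batch, seq_k, d_v)

# Attention(Q, K, V) = Softmax(Q K^T / sqrt(d_k)) V
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)  # (10, 3, 6)
weights = F.softmax(scores, dim=-1)                        # each row sums to 1
result = torch.bmm(weights, V)                             # (10, 3, 5)
print(result.shape)  # torch.Size([10, 3, 5])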
In NLP, attention mechanisms are currently applied mostly within the seq2seq architecture, i.e., the encoder-decoder model.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        '''
        Parameters
        ----------
        query_size :
            size of the last dimension of Q.
        key_size :
            size of the last dimension of K.
        value_size1 :
            size of the second-to-last dimension of V.
        value_size2 :
            size of the last dimension of V.
        output_size :
            size of the final output dimension.

        Returns
        -------
        None.
        '''
        super(Attn, self).__init__()
        # store the parameters on the instance
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
        # linear layer needed by step 1 of the attention computation
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
        # linear layer needed by step 3 of the attention computation
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)

    def forward(self, Q, K, V):
        '''
        Parameters
        ----------
        Q :
            query.
        K :
            key.
        V :
            value.

        Returns
        -------
        output, attn_weights.
        '''
        # Step 1: apply the chosen computation rule
        # (the first of the common rules above): concatenate Q and K along
        # the feature axis, apply a linear transformation, then softmax
        attn_weights = F.softmax(
            self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
        # second half of step 1: multiply the weight matrix with V;
        # both are 3-D tensors with a leading batch dimension, so use bmm
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
        # Step 2: indexing with [0] drops the batch dimension; per the rule
        # used in step 1, Q is concatenated with the step-1 result
        output = torch.cat((Q[0], attn_applied[0]), 1)
        # Step 3: apply a linear layer to the step-2 result; unsqueeze(0)
        # restores the batch dimension so the output is again a 3-D tensor
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1, 1, 32)
K = torch.randn(1, 1, 32)
V = torch.randn(1, 32, 64)
output = attn(Q, K, V)
print(output[0])
print(output[0].size())
print(output[1])
print(output[1].size())
tensor([[[-0.4780, -0.8592, -0.0441, 0.3630, -0.2187, 0.4067, 0.1207,
-0.1605, -0.1946, 0.1040, 0.3175, 0.0843, -0.7464, 0.0476,
-0.3781, -0.1802, 0.3017, -0.0185, -0.7358, -0.0910, -0.4604,
0.2660, 0.0579, -0.2636, -0.0465, 0.0414, 0.1218, -0.1899,
-0.3193, -0.2307, -0.2140, -0.6470, -0.4279, -0.0392, -0.0406,
-0.4341, 0.3089, -0.2082, 0.1910, -0.4590, 0.3347, 0.3408,
-0.4363, 0.0888, -0.4426, 0.0010, 0.0481, 0.0381, -0.2192,
-0.1486, 0.0732, 0.0757, 0.3941, -0.3240, 0.4901, -0.0752,
0.5024, 0.0526, -0.4127, 0.2608, 0.0244, -0.3662, -0.6898,
0.0220]]], grad_fn=<UnsqueezeBackward0>)
torch.Size([1, 1, 64])
tensor([[0.0305, 0.0051, 0.0227, 0.0222, 0.0319, 0.0246, 0.0148, 0.0170, 0.0299,
0.0591, 0.0749, 0.0125, 0.0309, 0.0138, 0.0163, 0.0961, 0.0189, 0.0987,
0.0184, 0.0106, 0.0123, 0.0085, 0.0143, 0.0346, 0.0213, 0.0205, 0.0215,
0.0224, 0.0138, 0.1286, 0.0134, 0.0400]], grad_fn=<SoftmaxBackward>)
torch.Size([1, 32])