自然语言处理入门——RNN架构解析

自然语言处理入门

RNN架构解析

认识RNN模型

  • RNN:中文称循环神经网络,一般以序列数据为输入,通过网络内部结构设计有效捕捉序列之间的关系特征,一般也是以序列形式进行输出。

  • RNN单层网络结构:

自然语言处理入门——RNN架构解析_第1张图片

  • 以时间步对RNN进行展开的单层网络结构:(这样看起来就和CNN比较像了)

自然语言处理入门——RNN架构解析_第2张图片

  • RNN的循环机制使模型隐层上一时间步产生的结果能够作为当下时间步输入的一部分。

  • 因为RNN结构能够很好的利用序列间的关系,所以针对自然界有连续性的输入如人类的语言,语音等有很好的处理,广泛应用于NLP领域的各项任务,如文本分类、情感分析、意图识别、机器翻译等。

  • 假设用户输入了What time is it? 下面是RNN处理的方式:

自然语言处理入门——RNN架构解析_第3张图片

  • 最后将最终的输出O5进行处理来解析用户意图。

自然语言处理入门——RNN架构解析_第4张图片

  • RNN模型的分类:
  • 从输入和输出结构进行分类:
    • N vs N - RNN
    • N vs 1 - RNN
    • 1 vs N - RNN
    • N vs M - RNN
  • 从RNN的内部构造进行分类:
    • 传统RNN
    • LSTM
    • Bi-LSTM
    • GRU
    • Bi-GRU
  • N vs N - RNN:输入和输出等长,一般用于生成等长度的诗句。
  • N vs 1 - RNN:要求输出是一个单独的值,只要在最后一个隐层的输出h上进行线性变换就可以了。大部分情况下为了明确结果,还要使用sigmoid或softmax进行处理,这种结构经常被应用在文本分类问题上。
  • 1 vs N - RNN:唯一的输入作用于每次的输出,可用于将图片生成文字任务等。
  • N vs M - RNN:这是一种不限输入输出长度的RNN结构,由编码器和解码器两部分组成,两部分内部结构都是某类RNN,也被称为seq2seq架构,输入数据先通过编码器,最终输出一个隐含变量c,再将c作用在解码器解码的每一时间步上,以保证输入信息被有效利用。

自然语言处理入门——RNN架构解析_第5张图片

  • seq2seq架构最早被提出应用于机器翻译,因为其输入输出不受限,如今也是应用最广的RNN模型结构,在机器翻译、阅读理解、文本摘要等众多领域都进行了非常多的应用实践。

传统RNN模型

自然语言处理入门——RNN架构解析_第6张图片

  • 以中间的方块为例,它的输入有两部分,分别是h(t-1)和x(t),代表上一时间步的隐层输出和此时间步的输入。他们进入到RNN结构体后,会“融合”一起,其实就是进行拼接,形成新的张量[x(t), h(t-1)],这个张量将通过一个全连接层(线性层),使用tanh双曲正切作为激活函数,最终得到该时间步的输出h(t),它将作为下一时间步的输入和x(t+1)一起进入结构体。

自然语言处理入门——RNN架构解析_第7张图片

  • h t = t a n h ( W t [ X t , h t − 1 ] + b t ) h_t = tanh(W_t[X_t, h_{t-1}] + b_t) ht=tanh(Wt[Xt,ht1]+bt)

  • 激活函数tanh:用于帮助调节流经神经网络的值,tanh函数将值压缩在-1和1之间。

自然语言处理入门——RNN架构解析_第8张图片

pytorch中RNN的应用:

  • RNN类在torch.nn.RNN中,其初始化主要参数解释:

    • input_size: The number of expected features in the input `x`. 输入张量x中特征维度大小
    • hidden_size: The number of features in the hidden state `h`. 隐层张量h中特征维度大小
    • num_layers: Number of recurrent layers. 隐含层数量
    • nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
  • RNN类实例化对象主要参数解释:

    • input: 输入张量x
    • h0: 初始化的隐层张量h
import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output)
print(output.shape)

print(hn)
print(hn.shape)

output:
tensor([[[-0.1823,  0.0859, -0.3405, -0.5000, -0.6306,  0.5065,  0.8202,
          -0.2139,  0.5886, -0.1549,  0.3035,  0.3669, -0.3702, -0.0026,
           0.0604,  0.1055, -0.0163, -0.2904,  0.2216,  0.0020],
         [ 0.2791, -0.1390,  0.3652,  0.0539, -0.5179,  0.7433, -0.0418,
           0.8043,  0.5498, -0.0131,  0.4987,  0.8964,  0.0033,  0.1708,
          -0.0594, -0.0106, -0.5742,  0.5557, -0.3524, -0.3199],
         [-0.0448,  0.2398, -0.1254,  0.5049,  0.6504,  0.6963,  0.5681,
           0.5640,  0.2442, -0.6644,  0.2833,  0.7397,  0.0966,  0.4050,
           0.5397,  0.1153, -0.5372, -0.4970,  0.0586, -0.1714]],

        [[ 0.5013,  0.1563,  0.1514,  0.1719, -0.3103, -0.4294, -0.6875,
           0.0665, -0.4604,  0.1708,  0.1925,  0.0077, -0.2452, -0.1904,
           0.4462,  0.0012, -0.2967, -0.5996, -0.0416,  0.0766],
         [ 0.6177,  0.4556,  0.3853,  0.2834,  0.3121, -0.1427, -0.4408,
          -0.1028, -0.7400,  0.2298, -0.5990,  0.4145,  0.1973,  0.1061,
           0.3418, -0.1150, -0.2209, -0.5048,  0.0269,  0.4954],
         [ 0.1053,  0.3156,  0.2890,  0.2079,  0.0477,  0.2353, -0.0389,
          -0.0014, -0.2171, -0.0972,  0.0658,  0.4972,  0.2478,  0.0355,
           0.4458,  0.0405, -0.5211, -0.2562, -0.1064,  0.5259]],

        [[-0.4234,  0.1803,  0.1560,  0.4580, -0.2345, -0.2388,  0.3107,
          -0.0058,  0.0634,  0.0977, -0.4543,  0.0582, -0.0860,  0.2199,
           0.1864, -0.5531,  0.5284, -0.2800, -0.0510,  0.0912],
         [-0.5156,  0.4014,  0.0628,  0.3032, -0.0117, -0.1661,  0.5899,
           0.1559,  0.2996, -0.4454,  0.0348,  0.0651, -0.5742,  0.2271,
           0.1080, -0.3659,  0.6118, -0.4189,  0.0549, -0.0393],
         [-0.0868,  0.5991,  0.1813,  0.5599,  0.3917, -0.3454,  0.0961,
           0.0566,  0.0284, -0.3377,  0.0170, -0.1184, -0.5352,  0.3805,
           0.1599, -0.1647,  0.2100, -0.5550,  0.1266,  0.2302]],

        [[-0.3534,  0.1374,  0.1209,  0.0387, -0.1049, -0.2417,  0.1742,
          -0.2224,  0.5119, -0.6369,  0.3746,  0.4883, -0.1907,  0.3288,
           0.1200,  0.0569,  0.0759, -0.1567,  0.3188,  0.2419],
         [ 0.0486,  0.3413,  0.1351,  0.5912, -0.3284, -0.3300, -0.0787,
           0.1665,  0.1738, -0.2786,  0.3029,  0.0880,  0.3581, -0.0811,
           0.4021, -0.3304, -0.2823, -0.2832,  0.1019,  0.3242],
         [-0.1608,  0.1246,  0.0863,  0.3260, -0.2099,  0.2095,  0.4521,
           0.4346,  0.3898, -0.4924,  0.2472,  0.2306,  0.3713, -0.0955,
           0.3075, -0.2875, -0.1641, -0.2343, -0.1563,  0.3321]],

        [[ 0.0825,  0.1601,  0.3947,  0.4077,  0.2677, -0.4913, -0.3839,
          -0.1458, -0.0360,  0.2327,  0.0738,  0.2608,  0.3539, -0.0649,
           0.2555, -0.2991, -0.2076, -0.2232, -0.2227,  0.1667],
         [ 0.1509,  0.4328,  0.1422,  0.5717, -0.1864, -0.0670,  0.1437,
           0.2307, -0.2267,  0.0637, -0.2936,  0.2159,  0.4971,  0.0890,
           0.2068, -0.5263,  0.3438, -0.2558,  0.1983,  0.4059],
         [ 0.2452,  0.6272,  0.1001,  0.2857,  0.2405, -0.3052, -0.0981,
           0.1137, -0.1352, -0.3110,  0.1155,  0.0432, -0.0898, -0.0357,
           0.1601, -0.0056, -0.1636, -0.4776,  0.1329,  0.3925]]],
       grad_fn=<StackBackward>)
torch.Size([5, 3, 20])

hn:
tensor([[[ 0.4038,  0.0203,  0.0424,  0.3863, -0.0233, -0.4767, -0.2328,
           0.6005, -0.1970, -0.1546,  0.1492, -0.4493, -0.8081,  0.7578,
          -0.6790,  0.1135, -0.1818, -0.0088, -0.2488, -0.1144],
         [-0.1782, -0.4026,  0.4985, -0.5878,  0.4833,  0.5260,  0.0710,
           0.5741,  0.3153, -0.0460,  0.1516,  0.3593, -0.7491,  0.4448,
          -0.6297,  0.2588,  0.4649, -0.2219, -0.5977,  0.3895],
         [-0.1429, -0.0431,  0.7642, -0.3143,  0.4679,  0.1455,  0.3204,
          -0.2070, -0.1016, -0.4045, -0.3219, -0.2693, -0.6370, -0.0010,
          -0.5872, -0.5141,  0.0144, -0.4947,  0.5004, -0.5219]],

        [[ 0.0825,  0.1601,  0.3947,  0.4077,  0.2677, -0.4913, -0.3839,
          -0.1458, -0.0360,  0.2327,  0.0738,  0.2608,  0.3539, -0.0649,
           0.2555, -0.2991, -0.2076, -0.2232, -0.2227,  0.1667],
         [ 0.1509,  0.4328,  0.1422,  0.5717, -0.1864, -0.0670,  0.1437,
           0.2307, -0.2267,  0.0637, -0.2936,  0.2159,  0.4971,  0.0890,
           0.2068, -0.5263,  0.3438, -0.2558,  0.1983,  0.4059],
         [ 0.2452,  0.6272,  0.1001,  0.2857,  0.2405, -0.3052, -0.0981,
           0.1137, -0.1352, -0.3110,  0.1155,  0.0432, -0.0898, -0.0357,
           0.1601, -0.0056, -0.1636, -0.4776,  0.1329,  0.3925]]],
       grad_fn=<StackBackward>)
torch.Size([2, 3, 20])

传统RNN模型的优缺点

  • 优势:

    • 由于内部结构简单,对计算资源要求低,相比之后学的RNN的变体参数总量少了很多,在短序列任务上性能、效果都表现优异。
  • 劣势:

    • 传统RNN在解决长序列之间的关联时,通过实践证明其表现很差,原因是在进行反向传播时,过长的序列导致梯度计算异常,发生梯度消失或爆炸。
  • 什么是梯度消失或爆炸:

    • 根据反向传播算法及链式法则,梯度的计算可以简化为以下公式

    • D n = σ ′ ( z 1 ) w 1 ⋅ σ ′ ( z 2 ) w 2 ⋅ . . . ⋅ σ ′ ( z n ) w n D_n = σ'(z_1)w_1 · σ'(z_2)w_2 · ... ·σ'(z_n)w_n Dn=σ(z1)w1σ(z2)w2...σ(zn)wn

    • 其中sigmoid的导数值是固定的,在[0, 0.25]之间,而一旦公式中w也小于1,那么通过这样的公式连乘之后,最终的梯度会变得非常非常小,这种现象称作梯度消失,反正,如果认为增大w的值,使其大于1,最终可能造成梯度过大,称作梯度爆炸。

    • 如果梯度消失,权重将无法被更新,最终导致训练失败。梯度爆炸带来的梯度过大,大幅更新网络参数,在极端情况下结果会溢出(NaN)

LSTM模型

  • LSTM也称长短时记忆结构,是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。同时LSTM的结构更复杂,它的结构可以分为四个部分解析:
    • 遗忘门
    • 输入门
    • 细胞状态
    • 输出门
  • LSTM的内部结构图如下:

自然语言处理入门——RNN架构解析_第9张图片

  • 遗忘门部分结构图与计算公式

自然语言处理入门——RNN架构解析_第10张图片

  • 遗忘门结构分析:与传统RNN内部结构非常相似,首先将当前时间步输入x(t)与上一时间步隐含状态h(t-1)拼接,得到[x(t), h(t)],然后通过全连接层进行变换,最后通过sigmoid函数进行激活得到f(t),我们可以将f(t)看作是门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量,遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t),h(t-1)计算得来的,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息。
  • 遗忘门内部结构过程演示:

自然语言处理入门——RNN架构解析_第11张图片

  • 激活函数sigmoid的作用:调节流经网络的值,sigmoid函数将值压缩在0和1之间

自然语言处理入门——RNN架构解析_第12张图片

  • 输入门部分结构图与计算公式:

自然语言处理入门——RNN架构解析_第13张图片

  • 输入门结构分析:第一个公式产生输入门门值的公式,和遗忘门公式几乎相同,区别只是在之后要作用的目标不同,这个公式意味着输入信息有多少要进行过滤,输入们的第二个公式与传统RNN内部结构计算相同。对于LSTM来说,得到的是当前细胞状态,而不是像经典RNN一样得到的是隐含状态。
  • 输入门内部结构过程演示:

自然语言处理入门——RNN架构解析_第14张图片

  • 细胞状态更新图与计算公式:

自然语言处理入门——RNN架构解析_第15张图片

  • 细胞状态更新分析:细胞更新的结构与计算公式非常容易理解,这里没有全连接层,只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果。最终得到更新后的C(t)作为下一个时间步输入的一部分。整个细胞状态更新过程就是对遗忘门和输入门的应用。(这里是乘法,不是卷积)
  • 细胞状态更新过程演示:

自然语言处理入门——RNN架构解析_第16张图片

  • 输出门部分结构图与计算公式:

自然语言处理入门——RNN架构解析_第17张图片

  • 输出门结构分析:输出门部分的公式也是两个,第一个是计算输出门的门值,他和遗忘门,输入门计算公式相同。第二个是使用这个门值产生隐含状态h(t),将作用于更新后的细胞状态C(t)上,并做tanh激活。最终得到h(t)作为下一时间步输入的一部分,整个输出门的过程,就是为了产生隐含状态h(t)。
  • 输出门内部结构过程演示:

自然语言处理入门——RNN架构解析_第18张图片

  • 什么是Bi-LSTM:指双向LSTM,没有改变LSTM本身任何的内部结构,只是将LSTM应用两次且方向不同,再将两次得到的LSTM结果进行拼接作为最终输出。

自然语言处理入门——RNN架构解析_第19张图片

  • Bi-LSTM结构分析:我们看到图中对"我爱中国"这句话或者叫这个输入序列,进行了从左到右和从右到左两次LSTM处理,将得到的结果张量进行了拼接作为最终输出。这种结构能够捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但是模型参数和计算复杂度也随之增加了一倍,一般需要对语料和计算资源进行评估后决定是否使用该结构。

在Pytorch中LSTM工具的使用:

  • 类初始化主要参数解释:

    • input_size: The number of expected features in the input x 输入张量x中特征维度大小
    • hidden_size: The number of features in the hidden state h 隐层张量h中特征维度大小
    • num_layers: Number of recurrent layers. 隐含层数量
    • bidirectional: If True, becomes a bidirectional LSTM. Default: False 是否选择使用双向LSTM,默认不使用。
  • 类实例化对象主要参数解释:

    • input: 输入张量x
    • h0: 初始化的隐层张量h
    • c0: 初始化的细胞状态张量c
  • 输出:

    • output: 每一层的输出
    • h_n: 每一层隐层张量
    • c_n: 细胞状态
import torch.nn as nn
import torch

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

print(output)
print(output.shape)

print(hn)
print(hn.shape)
print(cn)
print(cn.shape)


tensor([[[ 6.4420e-02,  6.2410e-02, -3.1918e-01, -1.7109e-01, -8.8068e-02,
           2.2602e-01, -3.2335e-01, -7.1837e-03,  8.5058e-02,  5.9582e-03,
          -2.9552e-01, -1.1200e-01, -9.9346e-02, -6.7507e-02,  3.3893e-01,
           2.0244e-01, -3.9047e-01, -5.2756e-02,  1.4811e-01,  1.9268e-01],
         [ 4.9299e-02, -8.9591e-02,  1.9591e-01,  1.7071e-01,  4.0422e-01,
          -3.1363e-01, -2.9693e-01,  1.2103e-01,  1.0956e-02,  7.6901e-02,
           2.3720e-01,  6.3182e-03, -3.2968e-01, -1.7734e-01,  1.4216e-01,
          -2.0142e-01, -3.4723e-01,  8.3657e-02,  1.3796e-01, -2.2886e-01],
         [-2.0725e-01,  2.0795e-01,  4.6395e-02,  1.3417e-02, -8.3942e-02,
           8.8684e-02, -1.4023e-01, -4.5434e-02,  3.4325e-01,  5.9968e-02,
          -1.6269e-01, -5.3913e-03,  3.2542e-02, -5.8951e-01, -1.8863e-01,
          -2.8661e-03, -2.8178e-01, -2.7340e-01,  3.1375e-01, -1.5382e-01]],

        [[-1.6560e-02,  2.0795e-02, -1.2222e-01, -9.0507e-02,  1.0746e-03,
           1.1501e-01, -2.1268e-01, -2.4595e-02,  8.3267e-02,  6.6421e-02,
          -9.1167e-02, -6.3106e-02, -2.9599e-02, -1.0097e-01,  1.0818e-01,
           1.6963e-01, -2.5700e-01, -6.5596e-02,  4.6941e-02,  2.0623e-01],
         [ 4.0872e-02,  4.0869e-03,  2.3773e-02,  1.2502e-01,  2.7042e-01,
          -8.8856e-02, -1.0137e-01,  1.1388e-02,  5.1466e-02,  6.7936e-02,
           1.7766e-01,  1.6299e-02, -1.0510e-01, -1.6442e-01,  1.3268e-01,
          -1.3334e-02, -2.8878e-01,  4.8023e-02, -3.1834e-03, -1.1887e-01],
         [-1.0880e-01,  1.4484e-01,  1.1940e-02, -7.1890e-02, -6.4775e-03,
           1.3693e-01, -1.0062e-01, -5.6859e-02,  2.5171e-01,  2.6230e-02,
          -1.9837e-02,  1.8749e-02,  2.6209e-02, -3.7243e-01, -1.3988e-01,
           3.8677e-02, -1.3416e-01, -2.4021e-01,  8.4407e-02, -1.2588e-01]],

        [[-3.1532e-02,  2.4761e-03, -5.6999e-02, -8.2566e-02,  1.0148e-01,
           4.1854e-02, -1.3071e-01, -2.1285e-02,  1.2221e-01,  6.5598e-02,
           2.8480e-02, -2.0582e-02,  8.1437e-04, -1.1505e-01,  3.2730e-02,
           1.2462e-01, -1.3487e-01, -5.3599e-02, -3.4288e-02,  1.3610e-01],
         [ 3.3348e-03,  3.9802e-02, -2.8489e-02,  2.0086e-02,  2.1576e-01,
          -6.5055e-02, -7.2176e-02, -2.6473e-02,  8.7836e-02,  1.1661e-02,
           1.6746e-01,  3.1245e-02, -1.7550e-02, -1.5325e-01,  1.1242e-01,
           4.9840e-02, -1.9759e-01,  2.8112e-02, -8.6413e-02, -1.6962e-02],
         [-7.4427e-02,  1.1710e-01, -6.9011e-03, -1.0997e-01,  5.7745e-02,
           8.5939e-02, -5.8605e-02, -2.1025e-02,  2.3754e-01, -4.8881e-03,
           4.5520e-02,  2.7855e-02,  1.5404e-02, -2.2167e-01, -8.8727e-02,
           6.0912e-02, -1.5258e-02, -1.6793e-01, -5.7971e-02, -4.9538e-02]],

        [[-2.6533e-02,  1.0154e-03, -2.2220e-02, -1.0029e-01,  1.5870e-01,
          -3.3663e-02, -9.2402e-02, -7.9979e-03,  1.1908e-01,  5.8657e-02,
           9.9747e-02,  1.0687e-02,  1.8266e-02, -1.2505e-01,  1.2888e-02,
           9.5571e-02, -6.4849e-02, -5.3283e-02, -1.0791e-01,  1.0686e-01],
         [-1.7761e-03,  7.6669e-02, -3.7589e-02, -2.1660e-02,  1.9882e-01,
          -4.7095e-02, -6.7763e-02, -3.8129e-02,  1.0340e-01, -2.2646e-02,
           1.5637e-01,  5.3467e-02,  3.2359e-02, -1.4168e-01,  9.8835e-02,
           6.7367e-02, -1.5017e-01,  9.0694e-03, -1.5094e-01,  4.0554e-02],
         [-7.0019e-02,  9.5896e-02, -2.4826e-02, -1.2313e-01,  1.3319e-01,
           1.6771e-02, -5.4803e-02, -1.4518e-03,  2.0063e-01, -2.8075e-02,
           9.3925e-02,  4.1257e-02,  3.0619e-02, -1.8029e-01, -2.7524e-02,
           7.9259e-02,  1.4481e-02, -1.2273e-01, -1.2385e-01, -7.2297e-03]],

        [[-8.5072e-03,  3.0404e-02, -1.3843e-02, -1.0111e-01,  1.8537e-01,
          -6.3342e-02, -7.6308e-02, -1.4481e-03,  1.0062e-01,  2.3903e-02,
           1.2981e-01,  3.1212e-02,  3.4981e-02, -1.3894e-01,  2.9375e-02,
           8.0636e-02, -3.9740e-02, -4.5989e-02, -1.5719e-01,  9.7053e-02],
         [-8.6290e-03,  7.1298e-02, -3.4861e-02, -5.4390e-02,  1.9481e-01,
          -4.2411e-02, -5.5742e-02, -2.8054e-02,  1.2071e-01, -3.2352e-02,
           1.5135e-01,  5.1046e-02,  6.2415e-02, -1.3878e-01,  8.7795e-02,
           8.3503e-02, -9.5001e-02,  4.8436e-04, -1.8418e-01,  5.5274e-02],
         [-4.5381e-02,  8.7877e-02, -2.7749e-02, -1.1548e-01,  1.6753e-01,
          -2.6462e-02, -6.3796e-02,  1.3007e-02,  1.9289e-01, -4.2623e-02,
           1.0475e-01,  5.0453e-02,  2.0430e-02, -1.5570e-01, -3.6697e-03,
           6.9533e-02,  3.9411e-02, -6.6003e-02, -1.5108e-01,  3.7231e-03]]],
       grad_fn=<StackBackward>)
torch.Size([5, 3, 20])

tensor([[[ 0.1539,  0.1394,  0.0428, -0.0649, -0.0027, -0.0192,  0.0837,
          -0.0667,  0.1427,  0.0111, -0.2659,  0.2150,  0.1839,  0.0269,
          -0.0279, -0.0288,  0.1615, -0.2518, -0.1585, -0.1057],
         [ 0.0517,  0.0547, -0.1118, -0.0778,  0.0819, -0.0255, -0.0973,
          -0.1697,  0.1611, -0.0012, -0.1939, -0.1689,  0.1731,  0.0815,
          -0.0412, -0.0041,  0.2454, -0.1557, -0.0754, -0.0730],
         [ 0.0637, -0.1256,  0.2050,  0.0290, -0.1689,  0.2387,  0.2580,
          -0.0188,  0.0372,  0.0115, -0.0500,  0.1425,  0.1629,  0.0287,
           0.0690, -0.0600,  0.2022, -0.2592, -0.0233, -0.1594]],

        [[-0.0085,  0.0304, -0.0138, -0.1011,  0.1854, -0.0633, -0.0763,
          -0.0014,  0.1006,  0.0239,  0.1298,  0.0312,  0.0350, -0.1389,
           0.0294,  0.0806, -0.0397, -0.0460, -0.1572,  0.0971],
         [-0.0086,  0.0713, -0.0349, -0.0544,  0.1948, -0.0424, -0.0557,
          -0.0281,  0.1207, -0.0324,  0.1514,  0.0510,  0.0624, -0.1388,
           0.0878,  0.0835, -0.0950,  0.0005, -0.1842,  0.0553],
         [-0.0454,  0.0879, -0.0277, -0.1155,  0.1675, -0.0265, -0.0638,
           0.0130,  0.1929, -0.0426,  0.1048,  0.0505,  0.0204, -0.1557,
          -0.0037,  0.0695,  0.0394, -0.0660, -0.1511,  0.0037]]],
       grad_fn=<StackBackward>)
torch.Size([2, 3, 20])

tensor([[[ 0.3522,  0.2659,  0.0630, -0.1412, -0.0063, -0.0369,  0.1510,
          -0.1600,  0.3845,  0.0327, -0.5553,  0.3588,  0.4086,  0.0488,
          -0.0647, -0.0562,  0.3051, -0.4392, -0.3177, -0.1919],
         [ 0.1868,  0.1283, -0.2144, -0.1914,  0.1720, -0.0421, -0.1722,
          -0.3835,  0.3128, -0.0024, -0.3928, -0.3446,  0.2782,  0.1226,
          -0.1138, -0.0118,  0.4258, -0.3516, -0.1506, -0.1498],
         [ 0.1064, -0.2042,  0.4189,  0.0742, -0.3043,  0.4558,  0.5030,
          -0.0542,  0.0815,  0.0236, -0.1197,  0.3385,  0.3621,  0.0485,
           0.1882, -0.1620,  0.5038, -0.4896, -0.0457, -0.2475]],

        [[-0.0162,  0.0641, -0.0293, -0.2272,  0.4017, -0.1098, -0.1557,
          -0.0026,  0.2079,  0.0456,  0.2684,  0.0647,  0.0686, -0.2643,
           0.0536,  0.1555, -0.0743, -0.0893, -0.2778,  0.1935],
         [-0.0166,  0.1548, -0.0735, -0.1159,  0.4260, -0.0708, -0.1082,
          -0.0534,  0.2580, -0.0575,  0.3164,  0.1054,  0.1184, -0.2603,
           0.1653,  0.1670, -0.1896,  0.0009, -0.3372,  0.1154],
         [-0.0863,  0.1792, -0.0580, -0.2601,  0.3489, -0.0469, -0.1318,
           0.0254,  0.4149, -0.0812,  0.2163,  0.1072,  0.0416, -0.2861,
          -0.0067,  0.1414,  0.0742, -0.1296, -0.2704,  0.0071]]],
       grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
  • LSTM优势:LSTM的们结构能够有效缓解长序列问题中可能出现的梯度消失或爆炸,虽然并不能杜绝这种现象,但在更长的序列问题上表现优于传统RNN
  • LSTM缺点:由于内部结构相对较复杂,因此训练效率在同等算力下较传统的RNN低得多。

GRU模型

  • GRU也称门控循环单元结构,也是传统RNN的变体,同LSTM一样能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象,同时它的结构和计算要比LSTM简单,它的核心结构可以分为两个部分:
    • 更新门
    • 重置门
  • GRU的内部结构和计算公式如下:

自然语言处理入门——RNN架构解析_第20张图片

  • 把z_t叫做更新门,r_t叫做重置门
  • GRU的更新门和重置门结构图:

自然语言处理入门——RNN架构解析_第21张图片

  • 内部结构分析:和之前分析过的LSTM中的门控一样,首先计算更新门和重置门的门值,分别是z(t)和r(t),计算方法就是使用X(t)与h(t-1)拼接进行线性变换,再经过sigmoid激活. 之后重置门门值作用在了h(t-1)上,代表控制上一时间步传来的信息有多少可以被利用. 接着就是使用这个重置后的h(t-1)进行基本的RNN计算,即与x(t)拼接进行线性变化,经过tanh激活,得到新的h(t)。最后更新门的门值会作用在新的h(t),而1-门值会作用在h(t-1)上,随后将两者的结果相加,得到最终的隐含状态输出h(t),这个过程意味着更新门有能力保留之前的结果,当门值趋于1时,输出就是新的h(t),而当门值趋于0时,输出就是上一时间步的h(t-1)。

  • Bi-GRU和Bi-LSTM的逻辑相同,不改变内部结构,将模型应用两次且方向不同,再将两次得到的结果进行拼接作为输出。

Pytorch中GRU工具的使用

  • GRU类初始化主要参数解释:
    • input_size: 输入张量x中特征维度大小
    • hidden_size: 隐层张量h中特征维度大小
    • num_layers: 隐含层数量
    • bidirectional: 是否双向
  • 实例化对象主要参数解释:
    • input: 输入张量x
    • h_0: 初始化隐层张量h
from torch.nn import GRU
import torch

rnn = GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

print(output)
print(output.shape)
print(hn)
print(hn.shape)


tensor([[[-6.9459e-01, -2.6020e-01, -8.8132e-01, -1.0471e+00, -1.5542e+00,
           5.9995e-01, -1.0064e+00,  6.1060e-01, -8.9793e-01, -4.6518e-01,
           3.3329e-01, -2.0619e-01,  4.2522e-01,  9.1905e-01,  2.7251e-01,
           4.4926e-01,  8.2407e-01, -6.3048e-01,  5.4013e-01, -1.2370e-01],
         [-1.5277e-02, -3.6983e-01,  6.6451e-01, -1.1350e-01,  3.6478e-01,
           4.9056e-01,  4.4266e-02, -5.3051e-02, -7.3330e-01,  5.5783e-01,
           7.9211e-01, -1.3003e-01, -2.9227e-01,  1.1308e+00,  4.2109e-01,
          -3.6290e-02,  1.6238e-01, -5.2465e-01, -7.3086e-01,  2.6433e-01],
         [ 3.4612e-01, -3.6740e-02,  3.5100e-01,  3.3513e-01,  5.4247e-01,
           2.8768e-02, -3.9264e-01,  9.4860e-02,  5.1338e-01, -8.6671e-01,
          -4.4661e-01, -4.3441e-01, -5.2957e-02,  1.0775e-01, -5.0672e-01,
           4.0542e-02, -3.8472e-01,  6.4337e-02,  5.0261e-01, -2.8174e-01]],

        [[-3.6883e-01, -3.1634e-01, -6.5082e-01, -6.0509e-01, -9.4691e-01,
           5.4833e-01, -5.3345e-01,  3.1349e-01, -5.9341e-01, -2.1943e-01,
           6.1886e-02,  1.5400e-02,  3.5497e-01,  6.2187e-01,  3.3893e-01,
           2.5864e-01,  5.2170e-01, -4.3129e-01,  2.5660e-01, -1.0916e-02],
         [-3.7010e-02, -3.9229e-01,  5.1009e-01, -1.6834e-01,  3.6872e-01,
           3.4876e-01,  7.4521e-03, -2.2716e-01, -4.9721e-01,  4.6610e-01,
           4.9178e-01, -5.5974e-02, -2.4143e-01,  8.2877e-01,  2.8780e-01,
          -1.5498e-01,  1.1125e-01, -3.0916e-01, -5.2575e-01, -6.7723e-02],
         [ 1.5300e-01,  6.2819e-02,  1.3036e-01,  2.5860e-01,  2.1189e-01,
           1.8184e-01, -3.0030e-01, -4.0117e-02,  1.0346e-01, -3.8004e-01,
          -1.6947e-01, -2.5304e-01, -1.5507e-01,  2.5656e-01, -2.6630e-01,
           1.3186e-01, -1.7718e-01,  3.9756e-02,  5.1584e-02, -7.1031e-02]],

        [[-2.0103e-01, -2.1229e-01, -4.5781e-01, -4.6843e-01, -4.5997e-01,
           4.2361e-01, -2.4420e-01,  1.3567e-01, -3.6607e-01, -2.7805e-02,
           8.2856e-02,  1.0767e-01,  2.0535e-01,  4.9118e-01,  4.4087e-01,
           7.6225e-02,  3.2751e-01, -2.9412e-01,  8.1042e-02, -4.2332e-03],
         [-3.8872e-02, -3.6614e-01,  4.3101e-01, -1.2236e-01,  3.5154e-01,
           3.7314e-01, -7.8891e-02, -3.3604e-01, -4.1319e-01,  3.7666e-01,
           3.0248e-01, -1.0486e-01, -2.3059e-01,  6.7035e-01,  1.2262e-01,
          -1.1550e-01,  1.0969e-01, -2.6315e-01, -4.6130e-01, -1.6450e-01],
         [-3.7442e-02,  1.5761e-02,  2.8101e-01,  5.1829e-02,  6.7561e-02,
           3.6606e-01, -2.6635e-01, -2.0923e-01, -1.2706e-01, -1.4952e-01,
          -3.7331e-02, -2.3069e-01, -1.4065e-01,  2.6480e-01, -2.1843e-01,
           1.2896e-01, -4.8500e-03, -1.1407e-02, -7.6343e-03, -9.7037e-02]],

        [[-8.8321e-02, -1.3655e-01, -4.5159e-01, -2.9163e-01, -1.3257e-01,
           2.8671e-01, -1.1048e-01,  3.5248e-02, -1.9131e-01,  4.7080e-04,
           1.2787e-01,  1.3459e-01,  1.9933e-01,  4.2335e-01,  5.2391e-01,
           3.3715e-02,  2.0858e-01, -2.7867e-01, -9.3280e-03,  3.8756e-02],
         [ 1.2003e-02, -3.1858e-01,  3.6018e-01, -5.8120e-02,  3.5037e-01,
           3.7452e-01, -4.6973e-02, -3.0096e-01, -2.8530e-01,  2.3618e-01,
           1.5918e-01, -6.6228e-02, -1.8211e-01,  4.8781e-01,  3.2872e-02,
          -1.1477e-02,  1.2382e-01, -3.0377e-01, -3.0077e-01, -8.1247e-02],
         [-2.6341e-02,  2.8667e-03,  2.4752e-01, -1.3758e-01,  8.2617e-02,
           4.0154e-01, -2.4415e-01, -2.5054e-01, -2.4248e-01,  1.6624e-02,
           1.2189e-01, -1.5927e-01, -1.3366e-01,  3.2527e-01, -8.8212e-02,
           1.0427e-01,  8.8365e-02, -1.0926e-01, -1.1473e-01, -8.1202e-02]],

        [[-5.4868e-02, -8.6880e-02, -3.0471e-01, -2.6974e-01,  4.6742e-02,
           1.9245e-01, -4.1078e-02, -7.0137e-02, -1.1890e-01, -8.7587e-03,
           2.3334e-01,  1.7913e-01,  1.2651e-01,  3.6180e-01,  4.3351e-01,
           3.4156e-02,  1.4778e-01, -3.2327e-01,  7.1391e-03, -3.2855e-02],
         [-8.4417e-03, -2.5009e-01,  4.2045e-01, -8.0115e-02,  2.9774e-01,
           3.6584e-01, -1.2970e-01, -2.9642e-01, -2.4022e-01,  1.3007e-01,
           1.5285e-01, -4.2584e-02, -1.5034e-01,  3.8386e-01, -5.4949e-02,
          -1.6722e-03,  6.9493e-02, -2.6682e-01, -2.6152e-01, -7.6451e-02],
         [-3.2398e-02, -4.3465e-02,  2.7941e-01, -1.2414e-01,  1.6196e-01,
           4.0282e-01, -1.7708e-01, -2.2177e-01, -2.7108e-01,  1.0402e-01,
           1.4508e-01, -1.1615e-01, -1.3291e-01,  3.3067e-01, -4.7963e-02,
           7.1976e-02,  1.2032e-01, -1.7992e-01, -1.6073e-01, -6.7630e-02]]],
       grad_fn=<StackBackward>)
torch.Size([5, 3, 20])

tensor([[[ 0.3411, -0.1189,  0.2968, -0.0787,  0.3222, -0.1785, -0.2443,
           0.2951,  0.0322, -0.2610, -0.2563, -0.0145, -0.1345,  0.1442,
           0.1701, -0.5196,  0.2953, -0.3805, -0.2769,  0.4867],
         [-0.0700,  0.0500, -0.0204, -0.1923,  0.2298,  0.0201,  0.0578,
           0.3886,  0.0550,  0.4585, -0.2719, -0.2821, -0.2649, -0.0644,
          -0.0072,  0.1250, -0.3758,  0.2249, -0.0802,  0.1650],
         [-0.0533,  0.1476,  0.2704, -0.0181,  0.0572,  0.1012,  0.0101,
           0.3611,  0.1178,  0.0130, -0.4630, -0.3046, -0.0808,  0.0264,
           0.1128,  0.1478, -0.1283,  0.0871,  0.0025, -0.1136]],

        [[-0.0549, -0.0869, -0.3047, -0.2697,  0.0467,  0.1924, -0.0411,
          -0.0701, -0.1189, -0.0088,  0.2333,  0.1791,  0.1265,  0.3618,
           0.4335,  0.0342,  0.1478, -0.3233,  0.0071, -0.0329],
         [-0.0084, -0.2501,  0.4205, -0.0801,  0.2977,  0.3658, -0.1297,
          -0.2964, -0.2402,  0.1301,  0.1529, -0.0426, -0.1503,  0.3839,
          -0.0549, -0.0017,  0.0695, -0.2668, -0.2615, -0.0765],
         [-0.0324, -0.0435,  0.2794, -0.1241,  0.1620,  0.4028, -0.1771,
          -0.2218, -0.2711,  0.1040,  0.1451, -0.1162, -0.1329,  0.3307,
          -0.0480,  0.0720,  0.1203, -0.1799, -0.1607, -0.0676]]],
       grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
  • GRU的优势:GRU和LSTM的作用相同,在捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果都优于传统RNN且计算复杂度比LSTM小
  • GRU的缺点:仍不能完全解决梯度消失问题,同时其作为RNN的变体,有着RNN结构的一大弊端:不可并行计算

注意力机制

  • 在观察事物时,之所以能快速判断一种事物,是因为我们大脑能很快把注意力放在事物最具有辨识度的部分从而做出判断,而并非从头到尾的观察一遍事物之后,才能有判断结果,正是基于这样的理论,产生了注意力机制。

  • 什么是注意力计算规则:

    • 需要三个指定的输入Q(query), K(key), V(value),然后通过计算公式得到注意力的结果,这个结果代表query在key和value作用下的注意力表示,当输入的Q=K=V时,称作自注意力计算规则。
  • 常见的注意力计算规则:

    • 将Q, K进行纵轴拼接,做一次线性变换,再使用softmax处理获得结果最后与V做张量乘法

    • A t t e n t i o n ( Q , K , V ) = S o f t m a x ( L i n e a r ( [ Q , K ] ) ⋅ V Attention(Q, K, V) = Softmax(Linear([Q, K])·V Attention(Q,K,V)=Softmax(Linear([Q,K])V

    • 将Q, K进行纵轴拼接,做一次线性变换后使用tanh函数激活,然后在进行内部求和,最后使用softmax处理获得结果再与V做张量乘法

    • A t t e n t i o n ( Q , K , V ) = S o f t m a x ( s u m ( t a n h ( L i n e a r ( [ Q , k ] ) ) ) ) ⋅ V Attention(Q,K,V) = Softmax(sum(tanh(Linear([Q,k]))))·V Attention(Q,K,V)=Softmax(sum(tanh(Linear([Q,k]))))V

    • 将Q与K的转置做点积运算,然后除以一个缩放系数,再使用softmax处理获得结果最后与V做张量乘法。

    • A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q ⋅ K T d k ) ⋅ V Attention(Q,K,V) = Softmax(\frac{Q·K^T}{\sqrt{d_k}})·V Attention(Q,K,V)=Softmax(dk QKT)V

  • 如果注意力权重矩阵和V都是三维张量且第一维代表batch条数时,则做bmm运算,bmm是一种特殊的张量乘法运算。

import torch

input = torch.rand(10, 3, 4)
mat2 = torch.rand(10, 4, 5)
# 如果参数1形状(b*n*m),参数2形状为(b*m*p),则输出为(b*n*p)
res = torch.bmm(input, mat2)
print(res)
print(res.shape)


tensor([[[1.5089, 0.7716, 0.7321, 1.4452, 0.8569],
         [1.3220, 0.5150, 0.9651, 0.9369, 0.7922],
         [1.1471, 0.5643, 0.6489, 1.1991, 0.6842]],

        [[1.0974, 1.6059, 1.0994, 1.4619, 1.3693],
         [1.1033, 1.6071, 0.9072, 1.0177, 1.1216],
         [1.0452, 1.3168, 1.0817, 1.0244, 0.8911]],

        [[1.5674, 0.9442, 1.7458, 1.3762, 0.8584],
         [1.4692, 0.9947, 1.6580, 1.3287, 0.8270],
         [0.4599, 0.4040, 0.5983, 0.4107, 0.4764]],

        [[1.7093, 1.6265, 1.2188, 0.9478, 1.5052],
         [1.4540, 1.2161, 1.5406, 0.6138, 1.4907],
         [1.5233, 1.1481, 1.1055, 0.8532, 1.5669]],

        [[0.4366, 0.7821, 0.4608, 0.5289, 0.2629],
         [0.8268, 1.4866, 0.4945, 1.0979, 0.9051],
         [0.9979, 1.3817, 0.6230, 1.1372, 0.9892]],

        [[0.7191, 0.9421, 1.4818, 0.4685, 0.6745],
         [0.7998, 1.3734, 1.5346, 0.5162, 1.1392],
         [0.8207, 1.0578, 1.8102, 0.6123, 1.2731]],

        [[0.8121, 0.8512, 1.6110, 1.5534, 1.1897],
         [0.1294, 0.3024, 0.3043, 0.3204, 0.3466],
         [0.6966, 0.6364, 1.6101, 1.1096, 0.9638]],

        [[0.3239, 0.6002, 0.4062, 0.6762, 0.3783],
         [1.5538, 2.0881, 1.2786, 2.2376, 1.6499],
         [1.3914, 2.2420, 1.3859, 2.2466, 1.5593]],

        [[1.2858, 0.7075, 0.7151, 0.9685, 0.4747],
         [1.4487, 1.0928, 1.0394, 1.1043, 0.6741],
         [2.1023, 1.6332, 1.5035, 1.3093, 1.3948]],

        [[1.0718, 0.4500, 1.0227, 0.6682, 0.5366],
         [0.7295, 0.6790, 0.5071, 0.6125, 0.5268],
         [0.9271, 1.2282, 0.7771, 0.7781, 0.9138]]])
torch.Size([10, 3, 5])
  • 什么是注意力机制:注意力机制是注意力计算规则能够应用的深度学习网络的载体,同时包括一些必要的全连接层以及相关张量处理,使其与应用网络融为一体,使自注意力规则的注意力机制称为自注意力机制。

在NLP领域中,当前的注意力机制大多数应用于seq2seq架构,即编码器解码器模型

  • 注意力机制的作用:
    • 在解码器端的注意力机制:能够根据模型目标有效的聚焦编码器的输出结果,当其作为解码器的输入时提升效果,改善以往编码器输出单一定长张量无法存储过多信息的情况。
    • 在编码器段的注意力机制:主要解决表征问题,相当于特征提取过程,得到输入的注意力表示,一般使用自注意力(self-attention)
  • 注意力机制实现步骤:
    • 第一步:根据注意力计算规则,对Q,K,V进行相应计算
    • 第二步:根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积,一般是自注意力,Q与V相同,则不需要进行与Q的拼接。
    • 第三步:最后为了整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做一个线性变换,得到最终对Q的注意力表示。
  • 常见注意力机制的代码分析:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        '''
        

        Parameters
        ----------
        query_size : 
            Q的维度(最后一维的大小).
        key_size : 
            K的维度(最后一维的大小).
        value_size1 : 
            V的倒数第二维的大小.
        value_size2 : 
            V的倒数第一维的大小.
        output_size : TYPE
            最后一次输出的维度大小.

        Returns
        -------
        None.

        '''
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
        
        # 初始化注意力机制实现第一步所需要的线性层
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
        
        # 初始化注意力机制实现第三部所需要的线性层
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
        
        
        
    def forward(self, Q, K, V):
        '''
        

        Parameters
        ----------
        Q : 
            query.
        K : 
            key.
        V : 
            value.

        Returns
        -------
        output, attn_weights.

        '''
        # 第一步,按照计算规则进行计算
        # 采用常见的第一种计算规则
        # 将Q,K进行纵轴拼接,做一次线性变换,最后使用softmax处理获得结果
        attn_weights = F.softmax(
            self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
        
        # 然后进行第一步的后半部分,将得到的权重矩阵与V做矩阵乘法计算
        # 当二者都是三维张量且第一维代表batch条数时,则做bmm运算
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
        
        # 之后进行第二步,通过取[0]是用来降维,根据第一步采用的计算方法
        # 需要将Q与第一步计算结果进行拼接
        output = torch.cat((Q[0], attn_applied[0]), 1)
        
        # 最后第三步,使用线性层作用在第三步结果上做一个线性变换并扩展维度,得到输出
        # 因为要保证输出也是三维张量,所以使用unsqueeze(0)扩展维度
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights
    
    
    
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64

attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)

output = attn(Q, K, V)
print(output[0])
print(output[0].size())
print(output[1])
print(output[1].size())


tensor([[[-0.4780, -0.8592, -0.0441,  0.3630, -0.2187,  0.4067,  0.1207,
          -0.1605, -0.1946,  0.1040,  0.3175,  0.0843, -0.7464,  0.0476,
          -0.3781, -0.1802,  0.3017, -0.0185, -0.7358, -0.0910, -0.4604,
           0.2660,  0.0579, -0.2636, -0.0465,  0.0414,  0.1218, -0.1899,
          -0.3193, -0.2307, -0.2140, -0.6470, -0.4279, -0.0392, -0.0406,
          -0.4341,  0.3089, -0.2082,  0.1910, -0.4590,  0.3347,  0.3408,
          -0.4363,  0.0888, -0.4426,  0.0010,  0.0481,  0.0381, -0.2192,
          -0.1486,  0.0732,  0.0757,  0.3941, -0.3240,  0.4901, -0.0752,
           0.5024,  0.0526, -0.4127,  0.2608,  0.0244, -0.3662, -0.6898,
           0.0220]]], grad_fn=<UnsqueezeBackward0>)
torch.Size([1, 1, 64])

tensor([[0.0305, 0.0051, 0.0227, 0.0222, 0.0319, 0.0246, 0.0148, 0.0170, 0.0299,
         0.0591, 0.0749, 0.0125, 0.0309, 0.0138, 0.0163, 0.0961, 0.0189, 0.0987,
         0.0184, 0.0106, 0.0123, 0.0085, 0.0143, 0.0346, 0.0213, 0.0205, 0.0215,
         0.0224, 0.0138, 0.1286, 0.0134, 0.0400]], grad_fn=<SoftmaxBackward>)
torch.Size([1, 32])

你可能感兴趣的:(NLP,自然语言处理,rnn,架构)