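This post walks through the Conformer by way of its implementation in WeNet; see the original paper for the full details.
For a detailed walkthrough of the WeNet code itself, see the Zhihu author '迷途小书僮'.
Conformer paper: "Conformer: Convolution-augmented Transformer for Speech Recognition"
For reference, the well-known Transformer paper is "Attention Is All You Need". The Conformer builds on the Transformer architecture by adding a convolution module to each block.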
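Overall Conformer structure: SpecAug, Convolution Subsampling, Linear, Dropout, then N ConformerBlocks.
ConformerBlock structure (repeated N times): Feed Forward Module, Multi-Head Self Attention Module, Convolution Module, Feed Forward Module, LayerNorm. Each module starts with a LayerNorm and ends with Dropout, and is wrapped in a residual connection whose shortcut is simply the module's input.
Macaron structure: the ConformerBlock resembles a macaron, with two identical Feed Forward Modules sandwiching the Multi-Head Self Attention Module and the Convolution Module. In the paper, each Feed Forward Module adds only half of its output to the residual path (a "half-step" residual).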
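The block layout above can be sketched in PyTorch as follows. This is a minimal illustration, not WeNet's encoder layer: the class name ConformerBlockSketch is hypothetical, the four sub-modules are passed in as arbitrary shape-preserving layers, masks and caching are omitted, and the 0.5 factor on the two Feed Forward residuals follows the half-step residual described in the Conformer paper.

```python
import torch
import torch.nn as nn


class ConformerBlockSketch(nn.Module):
    """Hypothetical macaron-style Conformer block (no masks, no cache).

    ff1, mhsa, conv and ff2 must each map (B, T, d) -> (B, T, d).
    """

    def __init__(self, d, ff1, mhsa, conv, ff2, dropout_rate=0.1):
        super().__init__()
        self.ff1, self.mhsa, self.conv, self.ff2 = ff1, mhsa, conv, ff2
        # One pre-LayerNorm per module plus the final LayerNorm.
        self.norms = nn.ModuleList([nn.LayerNorm(d) for _ in range(5)])
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        n = self.norms
        x = x + 0.5 * self.dropout(self.ff1(n[0](x)))  # half-step FFN residual
        x = x + self.dropout(self.mhsa(n[1](x)))       # attention residual
        x = x + self.dropout(self.conv(n[2](x)))       # convolution residual
        x = x + 0.5 * self.dropout(self.ff2(n[3](x)))  # half-step FFN residual
        return n[4](x)                                  # final LayerNorm


# Usage with trivial stand-in sub-modules, just to check the wiring.
d = 16
block = ConformerBlockSketch(d, nn.Linear(d, d), nn.Linear(d, d),
                             nn.Linear(d, d), nn.Linear(d, d)).eval()
with torch.no_grad():
    out = block(torch.randn(2, 10, d))
print(out.shape)  # torch.Size([2, 10, 16])
```

Passing the sub-modules in keeps the sketch focused on the residual and normalization pattern, which is the part the macaron description is about.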
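The WeNet speech recognition toolkit lets you choose the Conformer as its encoder, and the results published with the official release show a Conformer encoder outperforming a Transformer encoder. For that reason, the Conformer is the more common encoder choice when training with WeNet.
The relevant code lives mainly in encoder.py, subsampling.py, embedding.py, encoder_layer.py, positionwise_feed_forward.py, attention.py, and convolution.py.
In WeNet, the extracted speech features are fed into the encoder.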
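The first stage processes the input features.
Code: encoder.py -> subsampling.py & embedding.py
a. The input is subsampled first. WeNet offers several options in subsampling.py: Conv2dSubsampling4, Conv2dSubsampling6, and Conv2dSubsampling8.
b. In our experiments, training on 10,000 hours of data with Conv2dSubsampling6 or Conv2dSubsampling8 produced NaN losses, so we settled on Conv2dSubsampling4 for the Convolution Subsampling layer.
c. Conv2dSubsampling4 applies two Conv2d layers with 3×3 kernels and stride 2 to the input; each one roughly halves the time axis, so together they subsample the input by a factor of 4.
d. The Convolution Subsampling code is below. In the experiment, the output dimension odim was 512 and batch_size was 14.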
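self.conv = torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2),     # 1 input channel, 3×3 kernel, stride 2, no padding
    torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2),  # second stride-2 conv: 4× subsampling in total
    torch.nn.ReLU(),
)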
e. Tensor shapes observed in the experiment
Before Conv2dSubsampling4
(Pdb) p xs.shape
torch.Size([14, 975, 80])
After Conv2dSubsampling4
(Pdb) p xs.shape
torch.Size([14, 243, 512])
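After Conv2dSubsampling4, dimension 1 (time) shrinks from 975 to 243, i.e., 4× subsampling. Because the convolutions are unpadded, each one maps a length L to (L - 1) // 2, so the time axis goes 975 -> 487 -> 243 rather than exactly 975 / 4. The last dimension changes from 80 (the input feature dimension) to 512 (odim).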
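To reproduce the shape change above, here is a minimal PyTorch sketch of a 4× convolutional subsampling layer. The class name Subsampling4Sketch and the helper conv_out_len are hypothetical; the two convolutions are the ones shown in item d, and the sketch assumes, as a simplification, that a single Linear layer projects the flattened (odim × reduced frequency) features back to odim. Positional encoding and mask handling are left out.

```python
import torch
import torch.nn as nn


def conv_out_len(length: int, kernel: int = 3, stride: int = 2) -> int:
    """Length after one unpadded convolution along a single axis."""
    return (length - kernel) // stride + 1


class Subsampling4Sketch(nn.Module):
    """Hypothetical 4x subsampling: two 3x3 stride-2 Conv2d + one Linear."""

    def __init__(self, idim: int = 80, odim: int = 512):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, odim, 3, 2), nn.ReLU(),
            nn.Conv2d(odim, odim, 3, 2), nn.ReLU(),
        )
        freq = conv_out_len(conv_out_len(idim))      # 80 -> 39 -> 19
        self.out = nn.Linear(odim * freq, odim)      # flattened channels x freq -> odim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x.unsqueeze(1))                # (B, T, idim) -> (B, odim, T', freq)
        b, c, t, f = x.size()
        x = x.transpose(1, 2).reshape(b, t, c * f)   # (B, T', odim * freq)
        return self.out(x)                           # (B, T', odim)


print(conv_out_len(conv_out_len(975)))               # 243
with torch.no_grad():
    ys = Subsampling4Sketch()(torch.randn(1, 975, 80))
print(ys.shape)                                      # torch.Size([1, 243, 512])
```

A batch of 1 keeps the demo cheap; with a batch of 14 the output would be (14, 243, 512), matching the pdb printout.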
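Code: ConformerBlock: encoder.py -> encoder_layer.py
a. After this preprocessing, the data enters the ConformerBlocks; the experiment used N = 12 blocks.
b. The ConformerBlock equations follow directly from the block structure described earlier, so I won't go through them again here.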
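For reference, these are the block equations as given in the Conformer paper, for input x_i to block i and output y_i. FFN, MHSA and Conv denote the full modules, each including its own pre-LayerNorm and dropout; the ½ factors are the half-step residuals of the macaron layout.

```latex
\begin{aligned}
\tilde{x}_i &= x_i + \tfrac{1}{2}\,\mathrm{FFN}(x_i) \\
x'_i &= \tilde{x}_i + \mathrm{MHSA}(\tilde{x}_i) \\
x''_i &= x'_i + \mathrm{Conv}(x'_i) \\
y_i &= \mathrm{Layernorm}\left(x''_i + \tfrac{1}{2}\,\mathrm{FFN}(x''_i)\right)
\end{aligned}
```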
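Code: Feed Forward Module: encoder.py -> encoder_layer.py -> positionwise_feed_forward.py
a. Structure
The module consists mainly of two linear layers with a Swish activation between them. A ConformerBlock contains two Feed Forward Modules, and both have the same structure.
b. Code
self.w_1 = torch.nn.Linear(idim, hidden_units)   # expand: idim -> hidden_units
self.activation = activation                     # Swish in the Conformer
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = torch.nn.Linear(hidden_units, idim)   # project back: hidden_units -> idim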
Before the Feed Forward Module
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
tensor([[[ -0.9448, -1.6156, 0.8714, ..., -0.6092, 8.3108, -4.7808],
[ -1.9032, -0.7550, 0.0000, ..., 0.0000, 8.4346, -5.1248],
[ 0.0000, 0.1946, 1.7132, ..., 0.1049, 0.0000, 0.0000],
...,
[-12.1167, -2.8028, -9.4789, ..., -7.2485, 8.0067, -7.4645],
[ -0.5737, -3.4279, -2.4972, ..., -3.8395, 7.2618, -7.5300],
[ -0.0902, -1.9239, -0.6181, ..., 1.2788, 5.7042, 0.0000]],
[[ 0.0000, -3.5419, 0.0000, ..., -4.7302, 0.7809, -0.7692],
[ -4.9884, -1.4948, 0.4247, ..., -3.1005, 0.0594, 0.2251],
[ -5.3625, 0.0000, -1.1701, ..., -2.4318, -1.0293, 1.1975],
...,
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, 0.0000],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ 0.0000, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113]],
[[ 0.7055, 2.6856, -5.1106, ..., 0.6471, -5.7998, 7.7951],
[ 4.6700, 3.8121, -4.1586, ..., 0.9645, -5.9060, 7.7761],
[ 0.0000, 1.8728, -3.5463, ..., 1.7811, -5.1110, 8.0430],
...,
[ -9.5743, 0.0000, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., 0.0000, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113]],
...,
[[ -5.9051, 0.0000, 0.0000, ..., -1.7656, 1.2541, -2.9259],
[ -7.2387, 0.0000, -1.9941, ..., -1.9837, 3.7928, -3.0524],
[-11.5061, 4.6892, -2.9696, ..., -3.9930, 0.2395, -1.5693],
...,
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., 0.0000, 12.7505, -7.8113],
[ -9.5743, 0.0000, -13.6071, ..., -5.8194, 12.7505, -7.8113]],
[[ -4.0918, 1.2283, -3.7218, ..., -4.6928, 4.6370, -4.9392],
[ 0.0000, 2.6892, -2.0451, ..., -0.4186, 3.7887, -2.1638],
[ 0.0000, 3.7681, -0.6485, ..., -1.5094, 1.3679, -2.6088],
...,
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113]],
[[ -5.7208, 0.5752, -3.4696, ..., -3.7705, 11.7755, 1.7834],
[ -7.1437, 0.8723, -4.2169, ..., -3.9519, 11.1012, 1.8506],
[ -7.0842, 0.5833, -3.3386, ..., -3.5101, 11.4416, 1.6952],
...,
[ -9.5743, -7.7568, -13.6071, ..., 0.0000, 0.0000, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, -7.8113],
[ -9.5743, -7.7568, -13.6071, ..., -5.8194, 12.7505, 0.0000]]],
device='cuda:0', grad_fn=<FusedDropoutBackward0>)
After the Feed Forward Module
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
7.2815e+00, -7.6011e+00],
[ 3.1489e-02, -1.8858e+00, -4.5769e-01, ..., 1.3952e+00,
5.6306e+00, -3.0690e-02]],
[[-6.3229e-03, -3.4414e+00, -4.1966e-02, ..., -4.7639e+00,
6.3168e-01, -8.2709e-01],
[-4.9884e+00, -1.4948e+00, 4.0035e-01, ..., -2.9774e+00,
-6.4502e-03, 1.0855e-01],
[-5.3848e+00, 2.1896e-02, -1.2273e+00, ..., -2.4318e+00,
-1.0881e+00, 1.1193e+00],
...,
[-9.5743e+00, -7.6100e+00, -1.3619e+01, ..., -5.6579e+00,
1.2705e+01, -1.1259e-01],
[-9.4011e+00, -7.7079e+00, -1.3600e+01, ..., -5.7637e+00,
1.2693e+01, -7.8329e+00],
[ 1.6623e-01, -7.7410e+00, -1.3594e+01, ..., -5.7195e+00,
1.2776e+01, -7.9591e+00]],
[[ 7.0553e-01, 2.7021e+00, -5.1429e+00, ..., 6.9559e-01,
-5.9345e+00, 7.7479e+00],
[ 4.8407e+00, 3.8736e+00, -4.2463e+00, ..., 1.0054e+00,
-5.9870e+00, 7.6246e+00],
[ 1.0486e-01, 1.9964e+00, -3.6064e+00, ..., 1.7598e+00,
-5.1110e+00, 7.9280e+00],
...,
[-9.3324e+00, 1.0171e-01, -1.3625e+01, ..., -5.7424e+00,
1.2724e+01, -7.9598e+00],
[-9.3546e+00, -7.6776e+00, -1.3659e+01, ..., 0.0000e+00,
1.2747e+01, -7.9354e+00],
[-9.3427e+00, -7.7067e+00, -1.3743e+01, ..., -5.8194e+00,
1.2808e+01, -7.8516e+00]],
[[-3.8864e+00, 1.2016e+00, -3.7942e+00, ..., -4.6340e+00,
4.6085e+00, -5.0171e+00],
[ 1.6116e-01, 2.6320e+00, -2.1391e+00, ..., -4.4683e-01,
3.7133e+00, -2.2118e+00],
[ 1.1952e-01, 3.7163e+00, -7.2823e-01, ..., -1.5372e+00,
1.3950e+00, -2.6756e+00],
...,
[-9.3344e+00, -7.7682e+00, -1.3643e+01, ..., -5.6883e+00,
1.2767e+01, -7.9511e+00],
[-9.4534e+00, -7.7568e+00, -1.3671e+01, ..., -5.6858e+00,
1.2750e+01, -7.9653e+00],
[-9.3937e+00, -7.6900e+00, -1.3694e+01, ..., -5.7628e+00,
1.2702e+01, -7.8852e+00]],
[[-5.4615e+00, 5.6451e-01, -3.4696e+00, ..., -3.7742e+00,
1.1608e+01, 1.5476e+00],
[-7.1437e+00, 8.0560e-01, -4.2403e+00, ..., -3.9945e+00,
1.0934e+01, 1.6021e+00],
[-6.8293e+00, 5.6961e-01, -3.4408e+00, ..., -3.4578e+00,
1.1293e+01, 1.3934e+00],
...,
[-9.3903e+00, -7.7132e+00, -1.3642e+01, ..., 1.5418e-01,
-6.0617e-02, -7.9958e+00],
[-9.5743e+00, -7.7289e+00, -1.3659e+01, ..., -5.7391e+00,
1.2684e+01, -7.9137e+00],
[-9.3177e+00, -7.7448e+00, -1.3654e+01, ..., -5.8194e+00,
1.2788e+01, -1.5587e-01]]], device='cuda:0', grad_fn=<AddBackward0>)
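After the Feed Forward Module the tensor shape is unchanged, but the values differ because of the LayerNorm, the two linear layers, and the residual addition. The second Feed Forward Module performs the same operations.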
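A minimal PyTorch sketch of the Feed Forward Module, assuming Swish for the activation (nn.SiLU in PyTorch) and hidden_units = 2048, a common choice for a 512-dimensional model rather than a value reported for this experiment. The class name FeedForwardSketch is hypothetical. The pre-LayerNorm and the half-step residual, which judging by the code path above live in encoder_layer.py rather than inside the FFN, are shown in the usage lines; dropout on the residual path is omitted.

```python
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Hypothetical FFN body: Linear -> Swish -> Dropout -> Linear."""

    def __init__(self, idim=512, hidden_units=2048, dropout_rate=0.1):
        super().__init__()
        self.w_1 = nn.Linear(idim, hidden_units)   # expand
        self.activation = nn.SiLU()                # Swish: x * sigmoid(x)
        self.dropout = nn.Dropout(dropout_rate)
        self.w_2 = nn.Linear(hidden_units, idim)   # project back

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


norm_ff = nn.LayerNorm(512)
ffn = FeedForwardSketch().eval()
x = torch.randn(14, 243, 512)
with torch.no_grad():
    y = x + 0.5 * ffn(norm_ff(x))   # pre-LayerNorm + half-step residual
print(y.shape)  # torch.Size([14, 243, 512]): shape unchanged, values changed
```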
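Code: Multi-Head Self Attention Module: encoder.py -> encoder_layer.py -> attention.py
This part will get a separate, more detailed post.
a. Structure
The Multi-Head Self Attention Module is built mainly around multi-head self-attention with relative positional encoding (Multi-Head Self Attention with relative positional embedding). The experiment used 8 attention heads, so with a model dimension of 512 each head works on 512 / 8 = 64 dimensions.
b. Code
The code is in attention.py. It has many details, which I'll cover in that separate post.
Before the Multi-Head Self Attention Module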
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
(This is the same tensor as the Feed Forward Module output printed above, since that output is this module's input.)
After the Multi-Head Self Attention Module
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
tensor([[[-1.0271e+00, -1.9463e+00, 1.7051e+00, ..., -6.3574e-01,
8.7831e+00, -4.8254e+00],
[-1.7611e+00, -7.7109e-01, -9.2721e-02, ..., -6.5186e-02,
8.9391e+00, -5.1968e+00],
[-1.4328e-01, -1.5383e-01, 2.4761e+00, ..., -4.7934e-02,
5.2762e-01, -4.3704e-02],
...,
[-1.2330e+01, -3.0957e+00, -8.8320e+00, ..., -7.2978e+00,
8.4464e+00, -7.7051e+00],
[-5.8125e-01, -3.7171e+00, -1.6540e+00, ..., -3.8881e+00,
7.7412e+00, -7.6324e+00],
[ 3.1489e-02, -1.8858e+00, 3.3470e-01, ..., 1.2515e+00,
5.6306e+00, -4.4427e-02]],
[[-6.3229e-03, -3.3921e+00, 4.8096e-01, ..., -4.8776e+00,
7.2262e-01, -1.3703e+00],
[-5.3446e+00, -1.4451e+00, 9.1622e-01, ..., -3.1011e+00,
7.9614e-02, -4.2886e-01],
[-5.7261e+00, 6.9976e-02, -7.1110e-01, ..., -2.5597e+00,
-9.9130e-01, 5.8256e-01],
...,
[-9.9296e+00, -7.5594e+00, -1.3619e+01, ..., -5.6579e+00,
1.2793e+01, -6.5060e-01],
[-9.7484e+00, -7.6608e+00, -1.3076e+01, ..., -5.8727e+00,
1.2796e+01, -8.3650e+00],
1.2876e+01, -8.1519e+00]],
[[-3.8864e+00, 8.1121e-01, -3.2379e+00, ..., -4.6340e+00,
4.5953e+00, -5.4232e+00],
[-1.0173e-01, 2.2622e+00, -1.6064e+00, ..., -7.8868e-01,
3.7027e+00, -2.6075e+00],
[-1.6345e-01, 3.3597e+00, -1.9494e-01, ..., -1.8532e+00,
1.3767e+00, -3.0669e+00],
...,
[-9.6426e+00, -8.1642e+00, -1.3643e+01, ..., -6.0016e+00,
1.2767e+01, -8.3592e+00],
[-9.6984e+00, -8.1595e+00, -1.3115e+01, ..., -6.0114e+00,
1.2719e+01, -8.3549e+00],
[-9.6625e+00, -8.0812e+00, -1.3120e+01, ..., -6.0365e+00,
1.2702e+01, -8.2902e+00]],
[[-5.4615e+00, 7.9916e-01, -3.4321e+00, ..., -3.8729e+00,
1.1705e+01, 1.5476e+00],
[-6.9851e+00, 1.0249e+00, -4.2206e+00, ..., -4.0995e+00,
1.1052e+01, 1.2745e+00],
[-6.6968e+00, 8.0111e-01, -3.4151e+00, ..., -3.5240e+00,
1.1393e+01, 1.0780e+00],
...,
[-9.2554e+00, -7.4362e+00, -1.3618e+01, ..., 1.5418e-01,
5.9763e-02, -8.3539e+00],
[-9.4399e+00, -7.4798e+00, -1.3644e+01, ..., -5.8081e+00,
1.2792e+01, -8.1860e+00],
[-9.1668e+00, -7.5268e+00, -1.3658e+01, ..., -5.9279e+00,
1.2788e+01, -4.8667e-01]]], device='cuda:0', grad_fn=<AddBackward0>)
The Multi-Head Self Attention Module also leaves the tensor shape unchanged; only the values are transformed.
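The relative-position attention in attention.py is left for the separate post, but the head splitting and padding mask it shares with ordinary multi-head attention can be sketched. This minimal PyTorch example deliberately swaps in plain scaled dot-product attention with no positional terms; the class name PlainMHSASketch is hypothetical, and the mask convention (True marks a real frame) is an assumption.

```python
import math

import torch
import torch.nn as nn


class PlainMHSASketch(nn.Module):
    """Hypothetical plain multi-head self-attention (NO relative positions)."""

    def __init__(self, d_model: int = 512, n_head: int = 8):
        super().__init__()
        assert d_model % n_head == 0
        self.h, self.d_k = n_head, d_model // n_head   # 8 heads x 64 dims
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, x, mask):
        # x: (B, T, d_model); mask: (B, 1, T) bool, True = real frame
        b, t, _ = x.shape

        def split(z):  # (B, T, d_model) -> (B, h, T, d_k)
            return z.view(b, t, self.h, self.d_k).transpose(1, 2)

        q, k, v = split(self.linear_q(x)), split(self.linear_k(x)), split(self.linear_v(x))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)      # (B, h, T, T)
        scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))  # hide padded keys
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, t, -1)          # merge heads
        return self.linear_out(out)


torch.manual_seed(0)
mhsa = PlainMHSASketch().eval()
x = torch.randn(2, 243, 512)
lengths = torch.tensor([243, 100])
mask = (torch.arange(243)[None, :] < lengths[:, None]).unsqueeze(1)  # (B, 1, T)
with torch.no_grad():
    y = mhsa(x, mask)
print(y.shape)  # torch.Size([2, 243, 512])
```

Because padded keys receive -inf scores, changing the padded frames of an utterance does not affect the outputs at its real frames.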
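Code: Convolution Module: encoder.py -> encoder_layer.py -> convolution.py
a. Structure
The Convolution Module contains two pointwise convolutions and one depthwise convolution, applied in the order pointwise conv -> GLU -> depthwise conv -> batch norm -> Swish -> pointwise conv. The pointwise convolutions are Conv1d layers with kernel size 1 and stride 1; the depthwise convolution is a Conv1d with kernel size 31 and stride 1.
b. Code (core of the forward pass; x has shape (batch, channels, time) here)
# Pointwise conv doubles the channels so that GLU can halve them again
x = self.pointwise_conv1(x)          # (B, 2C, T)
x = nn.functional.glu(x, dim=1)      # (B, C, T)
# 1D Depthwise Conv
x = self.depthwise_conv(x)           # (B, C, T)
if self.use_layer_norm:
    x = x.transpose(1, 2)            # LayerNorm normalizes the last dimension
x = self.activation(self.norm(x))    # BatchNorm (or LayerNorm), then Swish
if self.use_layer_norm:
    x = x.transpose(1, 2)
x = self.pointwise_conv2(x)          # (B, C, T)
# mask batch padding (only when a non-empty mask is given)
if mask_pad.size(2) > 0:
    x.masked_fill_(~mask_pad, 0.0)   # zero out padded frames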
Before the Convolution Module
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
(This is the same tensor as the Multi-Head Self Attention Module output printed above, since that output is this module's input.)
After the Convolution Module
(Pdb) p x.shape
torch.Size([14, 243, 512])
(Pdb) p x
[ -9.7045, -8.1045, -13.0602, ..., -0.5921, 12.8330, -8.4895],
[ -9.6829, -8.1501, -13.1456, ..., -6.4132, 12.8319, -8.3690]],
...,
[[ -6.7102, 0.1239, 0.1389, ..., -1.9504, 1.5525, -3.0805],
[ -7.8718, -0.4553, -1.8495, ..., -1.8900, 4.1916, -3.3994],
[-11.9641, 4.3714, -2.5728, ..., -4.5918, 0.2353, -1.9963],
...,
[ -9.9326, -8.3047, -12.9756, ..., -5.7134, 12.7981, -8.2065],
[ -9.8994, -8.2996, -13.6337, ..., 0.1135, 12.7336, -8.2202],
[ -9.7145, -0.6672, -12.9155, ..., -5.7244, 12.8760, -8.1519]],
[[ -3.7198, 1.2760, -3.5933, ..., -4.8444, 4.7165, -5.4366],
[ -0.0161, 2.6363, -1.8160, ..., -1.1275, 3.7269, -2.8670],
[ 0.0800, 3.8032, -0.1034, ..., -1.9010, 1.3917, -3.0669],
...,
[ -9.6426, -8.1642, -13.6431, ..., -6.0016, 12.7671, -8.3592],
[ -9.6984, -8.1595, -13.1149, ..., -6.0114, 12.7185, -8.3549],
[ -9.6625, -8.0812, -13.1198, ..., -6.0365, 12.7020, -8.2902]],
[[ -5.7886, 1.3010, -3.4402, ..., -3.9976, 12.0589, 1.7659],
[ -6.8290, 1.5354, -4.4869, ..., -4.4033, 11.0707, 1.5978],
[ -6.7460, 0.8011, -3.4777, ..., -3.5240, 11.3933, 1.4383],
...,
[ -9.2554, -7.4362, -13.6181, ..., 0.1542, 0.0598, -8.3539],
[ -9.4399, -7.4798, -13.6437, ..., -5.8081, 12.7918, -8.1860],
[ -9.1668, -7.5268, -13.6582, ..., -5.9279, 12.7876, -0.4867]]],
device='cuda:0', grad_fn=<AddBackward0>)
After the Convolution Module the shape is still unchanged; the values change as a result of the three convolutions and the other operations.
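A minimal, runnable PyTorch sketch of the convolution module described above, assuming a non-causal depthwise convolution whose symmetric padding (kernel_size - 1) // 2 keeps the time length unchanged, BatchNorm between the depthwise convolution and Swish, and no streaming cache. The class name ConvModuleSketch is hypothetical; the layer names mirror the fragment above only for readability.

```python
import torch
import torch.nn as nn


class ConvModuleSketch(nn.Module):
    """Hypothetical non-causal Conformer convolution module (no cache)."""

    def __init__(self, channels: int = 512, kernel_size: int = 31):
        super().__init__()
        assert kernel_size % 2 == 1, "odd kernel + symmetric padding keeps T"
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
        # groups=channels: each channel gets its own 1-D filter (depthwise).
        self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size,
                                        padding=(kernel_size - 1) // 2, groups=channels)
        self.norm = nn.BatchNorm1d(channels)
        self.activation = nn.SiLU()  # Swish
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x, mask_pad):
        # x: (B, T, C); mask_pad: (B, 1, T) bool, True = real frame
        x = x.transpose(1, 2)                  # (B, C, T) for Conv1d
        x = x.masked_fill(~mask_pad, 0.0)      # keep padding out of the depthwise conv
        x = self.pointwise_conv1(x)            # (B, 2C, T)
        x = nn.functional.glu(x, dim=1)        # (B, C, T)
        x = self.depthwise_conv(x)             # (B, C, T)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)            # (B, C, T)
        x = x.masked_fill(~mask_pad, 0.0)      # zero padded frames again
        return x.transpose(1, 2)               # back to (B, T, C)


conv = ConvModuleSketch().eval()
x = torch.randn(2, 243, 512)
lengths = torch.tensor([243, 100])
mask_pad = (torch.arange(243)[None, :] < lengths[:, None]).unsqueeze(1)  # (B, 1, T)
with torch.no_grad():
    y = conv(x, mask_pad)
print(y.shape)  # torch.Size([2, 243, 512])
```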
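That covers how the Conformer is used as the encoder in WeNet. WeNet's overall structure matches the Conformer paper, but because input utterances differ in length, masks are introduced during processing. Masks are important, and they are best understood by reading them alongside the code.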
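As a minimal illustration of the masking idea, the sketch below builds a padding mask from utterance lengths (True marks a real frame) and subsamples it alongside the features. The helper name make_pad_mask_sketch is hypothetical, and the `[:, :, 2::2]` slicing is an assumption chosen because, applied twice, it yields exactly the (L - 1) // 2 lengths produced by the two unpadded stride-2 convolutions.

```python
import torch


def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: (B,) lengths -> (B, 1, max_len) bool mask, True = real frame."""
    steps = torch.arange(max_len, device=lengths.device)
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1)


lengths = torch.tensor([975, 600])           # two utterances padded to 975 frames
mask = make_pad_mask_sketch(lengths, 975)    # (2, 1, 975)

# Keep every second frame starting at index 2, twice, to follow the 4x subsampling:
# each slice turns a valid length L into (L - 1) // 2, like each convolution does.
mask4 = mask[:, :, 2::2][:, :, 2::2]         # (2, 1, 243)
print(mask4.shape, mask4.sum(dim=-1).squeeze(1).tolist())  # torch.Size([2, 1, 243]) [243, 149]
```

The 600-frame utterance keeps 149 valid frames after subsampling, which is (((600 - 1) // 2) - 1) // 2, so the mask stays aligned with the convolution output.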
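Reducing parameters: the Conformer has a large number of parameters. Future work could optimize the Multi-Head Self Attention Module and the Convolution Module to cut the parameter count and shrink the model.
Improving accuracy: every module in the ConformerBlock uses a leading LayerNorm, a trailing Dropout, and a residual connection. Could modifying the residual connections improve recognition accuracy? And is there room to improve the loss?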
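Regarding the parameter-reduction idea, it helps to see where a block's parameters actually sit before choosing what to slim down. The sketch below counts parameters for stand-in modules under assumed hyperparameters (d = 512, FFN hidden size 2048, 8 heads, kernel size 31) and ignores LayerNorms. nn.MultiheadAttention stands in for the relative-position attention; Transformer-XL-style relative attention adds a positional projection and learned bias vectors on top of this, so its real count is somewhat higher.

```python
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    """Trainable parameters only (buffers such as BN running stats are excluded)."""
    return sum(p.numel() for p in module.parameters())


d, ff_hidden, heads, kernel = 512, 2048, 8, 31   # assumed hyperparameters

ffn = nn.Sequential(nn.Linear(d, ff_hidden), nn.SiLU(), nn.Linear(ff_hidden, d))
mhsa = nn.MultiheadAttention(d, heads)           # stand-in for relative-position MHSA
conv = nn.ModuleList([
    nn.Conv1d(d, 2 * d, 1),                                  # pointwise conv before GLU
    nn.Conv1d(d, d, kernel, padding=kernel // 2, groups=d),  # depthwise conv
    nn.BatchNorm1d(d),
    nn.Conv1d(d, d, 1),                                      # second pointwise conv
])

counts = {"2 x FFN": 2 * n_params(ffn), "MHSA": n_params(mhsa), "Conv": n_params(conv)}
print(counts)  # {'2 x FFN': 4199424, 'MHSA': 1050624, 'Conv': 805376}
```

Under these assumptions the two Feed Forward Modules account for about 69% of the counted parameters, and the depthwise convolution itself is tiny (16,384 parameters); most of the Convolution Module's weight sits in its two pointwise convolutions.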
There is still plenty of room to improve the Conformer in WeNet, and I look forward to discussing it with everyone in the comments.