batch_first in nn.GRU

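I have recently been reproducing DeepSpeech for speech recognition. Its network consists of CNN and GRU layers followed by a linear classification layer. For the implementation I referred to the code at this link: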
https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/
The multi-layer GRU there is written as follows:

self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
            for i in range(n_rnn_layers)
        ])
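To make the effect of batch_first=i==0 concrete, the sketch below builds the same stack with plain nn.GRU layers standing in for BidirectionalGRU (a simplification) and prints each layer's input size and batch_first flag. The values n_rnn_layers=3 and rnn_dim=16 are illustrative assumptions, not taken from the referenced post.

```python
import torch.nn as nn

n_rnn_layers, rnn_dim = 3, 16  # illustrative values only

# Same indexing as the referenced code: layer 0 reads rnn_dim features and is
# batch_first; later layers read the 2*rnn_dim bidirectional output and are not.
layers = [
    nn.GRU(input_size=rnn_dim if i == 0 else rnn_dim * 2,
           hidden_size=rnn_dim, batch_first=i == 0, bidirectional=True)
    for i in range(n_rnn_layers)
]

for i, gru in enumerate(layers):
    print(i, gru.input_size, gru.batch_first)
# 0 16 True
# 1 32 False
# 2 32 False
```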

BidirectionalGRU is defined as follows:

import torch.nn as nn
import torch.nn.functional as F

class BidirectionalGRU(nn.Module):
    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()
        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)              # activation? (GELU applied before the GRU)
        x, _ = self.BiGRU(x)

        x = self.dropout(x)
        return x
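As a quick shape check, the sketch below repeats the class above so it runs on its own and stacks two layers with batch_first=True. The sizes (B=4, N=50, rnn_dim=16, dropout=0.1) are arbitrary assumptions. Because the GRU is bidirectional, each layer outputs 2*hidden_size features, which is why every layer after the first takes rnn_dim*2 as its input size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalGRU(nn.Module):
    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)  # normalizes the last (feature) dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)  # output features = 2 * hidden_size
        return self.dropout(x)

B, N, rnn_dim = 4, 50, 16  # illustrative sizes: batch, frames, features
layer1 = BidirectionalGRU(rnn_dim, hidden_size=rnn_dim, dropout=0.1, batch_first=True)
layer2 = BidirectionalGRU(rnn_dim * 2, hidden_size=rnn_dim, dropout=0.1, batch_first=True)

x = torch.randn(B, N, rnn_dim)  # [B, N, F]
y1 = layer1(x)
y2 = layer2(y1)
print(y1.shape, y2.shape)  # torch.Size([4, 50, 32]) torch.Size([4, 50, 32])
```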

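Here batch_first is controlled by i==0, where i is the index of the (i+1)-th layer. So only the first GRU has batch_first=True, and every later GRU has batch_first=False. Given the input conventions in the PyTorch GRU documentation, this causes a problem. Our input has shape [B, N, F], where B is the batch size, N the number of frames and F the frequency dimension. The first GRU outputs [B, N, hout] (hout = 2*rnn_dim, since the GRU is bidirectional). When this is fed into the second GRU with batch_first=False, the tensor is read as (sequence length, batch, features): B is treated as the time dimension and N as the batch size.

Because the last dimension still matches input_size, this trains without any error and does produce results. However, from the second layer on, the recurrence runs across the utterances in a batch instead of across frames: each frame is processed without context from neighbouring frames, and a sample's output depends on the other samples in the same batch, so results change with batch size or batch composition. batch_first should simply be set to True for every layer.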
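The two symptoms described above can be checked directly. The sketch below (illustrative sizes; the helper `changed` is hypothetical, written only for this demo) compares a single bidirectional nn.GRU used with batch_first=False, as in the later layers of the referenced code, against the same weights used with batch_first=True, both fed a [B, N, H] tensor like the one the first layer produces.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, H, hidden = 2, 5, 8, 4  # batch, frames, features, hidden size (illustrative)

# A later layer as built in the referenced code: batch_first=False
gru_false = nn.GRU(input_size=H, hidden_size=hidden, batch_first=False, bidirectional=True)
# The suggested fix: batch_first=True, with identical weights
gru_true = nn.GRU(input_size=H, hidden_size=hidden, batch_first=True, bidirectional=True)
gru_true.load_state_dict(gru_false.state_dict())

x = torch.randn(B, N, H)  # layout coming out of the first layer: [B, N, H]

def changed(gru, x_new, idx):
    """Return True if gru's output at position idx differs between x and x_new."""
    with torch.no_grad():
        a, _ = gru(x)
        b, _ = gru(x_new)
    return not torch.allclose(a[idx], b[idx])

# 1) Modify only utterance 1 and look at utterance 0's output
x_other = x.clone()
x_other[1] = torch.randn(N, H)
leak_false = changed(gru_false, x_other, 0)  # True: utterances are mixed
leak_true = changed(gru_true, x_other, 0)    # False: utterances are independent

# 2) Modify only frame 0 of utterance 0 and look at frame 1 of the same utterance
x_frame = x.clone()
x_frame[0, 0] = torch.randn(H)
ctx_false = changed(gru_false, x_frame, (0, 1))  # False: no context across frames
ctx_true = changed(gru_true, x_frame, (0, 1))    # True: frame 1 sees frame 0

print(leak_false, leak_true, ctx_false, ctx_true)  # True False False True
```

With batch_first=False, the backward direction at "time step" 0 (utterance 0) reads utterance 1, while the frames of one utterance sit in different batch slots and never exchange information. Setting batch_first=True on every layer removes both effects.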
