在Pytorch中调用RNN模型的小细节

在pytorch中调用RNN模型的时候,使用self.rnn = nn.RNN(embedding_num,hidden_num)往往忽略了其中的一个参数,在点开RNN源码的时候,可以看到其中batch_first这个参数:在Pytorch中调用RNN模型的小细节_第1张图片

可以看到这个参数如果为True的话,你的输入输出的tensor形状为(batch,maxlen,embedding_num),但是这个参数默认是False的,所以如果忘记了这个参数,要把输入的batch_idx在多加一行transpose的代码,如下:

batch_text_idx = batch_text_idx.transpose(1,0,2)
#也就是说将原来的batch 和 maxlen 要对调

如果说直接设置为True的话,就不用去做转置的部分。 

你可能感兴趣的:(pytorch,rnn,深度学习)