pytorch中LSTM的输入与输出理解

在阅读本篇博客之前希望你在LSTM方面有一定的知识储备,熟悉LSTM网络的内部结构,方便更好的理解pytorch中有关LSTM相关的api。

一、参数理解

这里我根据lstm的结构定义了一些参数,参数具体含义可以看注释

batch_size = 10 #每个batch的大小
seq_len = 2000 #模仿输入到LSTM的句子长度
input_size = 30 #lstm中输入的维度
hidden_size = 18 #lstm中隐藏层神经元的个数
num_layers = 2 # 有多少层lstm

二、数据准备

input = torch.randn(batch_size,seq_len,input_size)

三、LSTM

1、batch_first=True

pytorch中lstm输入和输出分为两种形式,一种是batch优先,另外一种则是batch第二,具体情况是指定lstm种参数batch_first=True,batch_first默认是False对应batch第二的情况,使用中我们一般将batch_first设置为True,采用batch优先的方式。如下

lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)#,batch_first=True
out,(hn,cn) = lstm(input)
print(out.size())
print('*'*100)
print(hn.size())
print('*'*100)
print(cn.size())

程序输出:

torch.Size([10, 2000, 18])
****************************************************************************************************
torch.Size([2, 10, 18])
****************************************************************************************************
torch.Size([2, 10, 18])

out的size为[10, 2000, 18]分别是[batch_size ,seq_len ,hidden_size] 对应的我们看一下源码中给出的注释

batch_first: If ``True``, then the input and output tensors are provided as
(batch, seq, feature). Default: ``False``

对应的再看下hn,cn的输出都是[num_layers,batch_size,hidden_size]。

2、batch_first=Fasle

现在再来考虑一种情况当batchfirst是默认为Fasle的情况呢?为了和上边有区分我们重新设置输入数据和lstm参数

input = torch.randn(seq_len,embedding_dim)
input = torch.unsqueeze(input,dim=0)#将input变为[1,seq_len,embedding_dim] batch=1
#input = torch.randn(batch_size,seq_len,embedding_dim)#可自行注释前两行看一下结果
lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,num_layers=num_layers,batch_first=False)
out,(hn,cn) = lstm(input)
print(out.size())
print('*'*100)
print(hn.size())
print('*'*100)
print(cn.size())

输出:

torch.Size([1, 2000, 18])
****************************************************************************************************
torch.Size([2, 2000, 18])
****************************************************************************************************
torch.Size([2, 2000, 18])

在这里可以看出两次out的size都是一样的都为[batch_size ,seq_len ,hidden_size],但是在hn和cn的时候发生了变化,变成了[num_layers,seq_len,hidden_size]。这里是需要注意的。

再将其输入变化下

input = torch.randn(batch_size,seq_len,embedding_dim)
# input = torch.unsqueeze(input,dim=0)
lstm = nn.LSTM(input_size=embedding_dim,hidden_size=hidden_size,num_layers=num_layers,batch_first=False)#,batch_first=True
out,(hn,cn) = lstm(input)
print(out.size())
print('*'*100)
print(hn.size())
print('*'*100)
print(cn.size())

输出

torch.Size([10, 2000, 18])
****************************************************************************************************
torch.Size([2, 2000, 18])
****************************************************************************************************
torch.Size([2, 2000, 18])

你可能感兴趣的:(Pytorch,人工智能,深度学习,python,算法)