PyTorch LSTM hidden-state initialization size error: RuntimeError: Expected hidden[0] size (2, 14, 150), got [2, 64, 150]

The full error message:

Traceback (most recent call last):
  File "main.py", line 41, in <module>
    loss = models(x, size, y).abs()
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/kaggle/working/model.py", line 236, in forward
    emissions = self.bi_lstm_forward(sentence, sentence_lengths)
  File "/kaggle/working/model.py", line 227, in bi_lstm_forward
    lstm_out, self.hidden = self.lstm(embeds, hidden)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 759, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 686, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 226, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (2, 14, 150), got [2, 64, 150]

Roughly, this means the LSTM expected an initial hidden state of size [2, 14, 150], but actually received one of size [2, 64, 150].

Below is the method in the model code that builds the initial hidden state, along with the relevant configuration:

'''
config.py
'''
batch_size = 64
epochs = 50
embedding_dim = 300
hidden_dim = 300
'''
model.py
'''
def get_state(self):
    # num_directions = 2 for a bidirectional LSTM
    c0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    h0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)

My code always initializes the hidden state with shape [2, batch_size, hidden_dim // 2]. However, when iterating over the data in batches, not every batch actually has the configured batch_size: once fewer samples than batch_size remain, the final batch simply contains whatever is left. In this error, only 14 samples remained, so the last batch held 14 items rather than 64 (batch_size), and the hidden state for that batch should therefore have shape [2, 14, 150].
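The arithmetic behind the short final batch can be sketched in plain Python (the sample count 910 is illustrative, chosen so the leftover matches the 14 in the error):

```python
def batch_sizes(n_samples, batch_size):
    """Size of each batch a plain sequential loader would yield."""
    return [min(batch_size, n_samples - start)
            for start in range(0, n_samples, batch_size)]

sizes = batch_sizes(910, 64)
# fourteen full batches of 64, then one final batch of only 14 samples
```

Any initial hidden state built from the configured batch_size will mismatch on that last, smaller batch.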

The fix

Pass the current batch tensor into get_state and read the batch size off the tensor itself (here the batch dimension is assumed to be dim 0, i.e. batch_first=True):

def get_state(self, input):
    # dynamically derive the batch size from the actual input tensor
    batch_size = input.size(0)
    # num_directions = 2 for a bidirectional LSTM
    c0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    h0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)
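An alternative, if the model truly must see a fixed batch size, is to discard the incomplete final batch at loading time via the DataLoader's drop_last flag. A minimal sketch with stand-in random data (not the original dataset):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 910 samples, so batch_size=64 leaves 14 samples over.
features = torch.randn(910, 300)
labels = torch.zeros(910, dtype=torch.long)

loader = DataLoader(TensorDataset(features, labels),
                    batch_size=64, drop_last=True)

# With drop_last=True every yielded batch has exactly 64 samples;
# the 14 leftover samples are skipped for that epoch.
sizes = [x.size(0) for x, _ in loader]
```

The downside is that those leftover samples are never trained on in that epoch, so the dynamic get_state fix above is usually the better choice.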
