Traceback (most recent call last):
  File "main.py", line 41, in <module>
    loss = models(x, size, y).abs()
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/kaggle/working/model.py", line 236, in forward
    emissions = self.bi_lstm_forward(sentence, sentence_lengths)
  File "/kaggle/working/model.py", line 227, in bi_lstm_forward
    lstm_out, self.hidden = self.lstm(embeds, hidden)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 759, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 686, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 226, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (2, 14, 150), got [2, 64, 150]
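In short, the LSTM expected the hidden state hidden[0] (that is, h_0, the first element of the (h_0, c_0) tuple) to have shape [2, 14, 150], but the tensor passed in had shape [2, 64, 150].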
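The mismatch can be reproduced in isolation. This is a minimal sketch that assumes a single-layer bidirectional nn.LSTM with 150 units per direction, which matches hidden_dim // 2 = 150. It also assumes the default batch_first=False layout. The sequence length of 20 is an arbitrary illustrative value and is not taken from the original model.

```python
import torch
import torch.nn as nn

# Sizes matching the traceback: bidirectional LSTM, 150 units per direction (assumed 1 layer)
embedding_dim, hidden_per_dir = 300, 150
lstm = nn.LSTM(embedding_dim, hidden_per_dir, num_layers=1, bidirectional=True)

seq_len, last_batch = 20, 14  # seq_len is illustrative; 14 is the size of the short final batch
embeds = torch.randn(seq_len, last_batch, embedding_dim)  # (seq_len, batch, features)

# Initial state built for the configured batch_size (64) instead of the actual batch (14)
h0 = torch.zeros(2, 64, hidden_per_dir)
c0 = torch.zeros(2, 64, hidden_per_dir)

try:
    lstm(embeds, (h0, c0))
    error = None
except RuntimeError as exc:
    error = str(exc)  # PyTorch rejects the mismatched hidden-state batch dimension

print(error)
```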
'''
config.py
'''
batch_size = 64
epochs = 50
embedding_dim = 300
hidden_dim = 300
'''
model.py
'''
def get_state(self):
    # The batch dimension is fixed to config.batch_size (64), whatever the actual batch size is
    c0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    # dim 0 = num_layers * num_directions (num_directions = 2 for a bidirectional LSTM)
    h0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)
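In my code, the initial hidden state is always created with shape [2, batch_size, hidden_dim//2]. After looking into it, I found that not every batch contains the configured batch_size samples when data is fetched batch by batch. When fewer than batch_size samples remain, the last batch simply holds whatever is left. In this error only 14 samples were left, so that batch had 14 samples rather than 64 (batch_size). The hidden state for that batch should therefore have had shape [2, 14, 150].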
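The short final batch is easy to observe directly. The post does not show how batches are produced, so this sketch assumes torch.utils.data.DataLoader. It uses a hypothetical dataset of 142 samples, chosen because 142 = 2 × 64 + 14. With the default drop_last=False, the incomplete last batch is kept. Passing drop_last=True, a standard DataLoader parameter, discards that batch instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset: 142 = 2 * 64 + 14 samples, so the final batch has 14 rows
dataset = TensorDataset(torch.arange(142))

# Default drop_last=False: the final, incomplete batch is kept
sizes = [batch[0].size(0) for batch in DataLoader(dataset, batch_size=64)]
print(sizes)  # [64, 64, 14]

# drop_last=True: the incomplete batch is discarded
sizes_dropped = [batch[0].size(0) for batch in DataLoader(dataset, batch_size=64, drop_last=True)]
print(sizes_dropped)  # [64, 64]
```

Dropping the last batch also avoids the error, but the remaining samples are then skipped in every epoch. Sizing the hidden state per batch keeps all the data.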
One fix is to pass each input batch to get_state as an argument and read the size of that batch dynamically:
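def get_state(self, input):
    # Read the size of *this* batch; assumes dim 0 of `input` is the batch dimension
    # (e.g. a (batch, seq_len) tensor of token ids, or an LSTM built with batch_first=True)
    batch_size = input.size(0)
    c0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    # dim 0 = num_layers * num_directions (num_directions = 2 for a bidirectional LSTM)
    h0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)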
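To check that the per-batch get_state handles both a full batch and the short last batch, here is a sketch with a hypothetical minimal model. TinyBiLSTM, its vocabulary size, and the module-level DEVICE constant (standing in for config.device) are illustrative assumptions, not the original model.py. The sketch uses batch_first=True so that input.size(0) really is the batch size. If your LSTM uses the default batch_first=False with (seq_len, batch) input, read input.size(1) instead.

```python
import torch
import torch.nn as nn

DEVICE = torch.device("cpu")  # stands in for config.device (assumption)


class TinyBiLSTM(nn.Module):
    """Hypothetical minimal model; only get_state mirrors the fix above."""

    def __init__(self, vocab_size=100, embedding_dim=300, hidden_dim=300):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # hidden_dim // 2 units per direction, so the concatenated output is hidden_dim wide
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            bidirectional=True, batch_first=True)

    def get_state(self, input):
        batch_size = input.size(0)  # actual size of this batch (batch_first layout)
        h0 = torch.zeros(2, batch_size, self.hidden_dim // 2, device=DEVICE)
        c0 = torch.zeros(2, batch_size, self.hidden_dim // 2, device=DEVICE)
        return (h0, c0)

    def forward(self, sentence):
        embeds = self.embedding(sentence)        # (batch, seq_len, embedding_dim)
        hidden = self.get_state(sentence)        # shapes follow the current batch
        lstm_out, _ = self.lstm(embeds, hidden)  # (batch, seq_len, hidden_dim)
        return lstm_out


model = TinyBiLSTM()
full = model(torch.randint(0, 100, (64, 20)))  # a full batch
last = model(torch.randint(0, 100, (14, 20)))  # the short final batch no longer fails
print(full.shape, last.shape)  # torch.Size([64, 20, 300]) torch.Size([14, 20, 300])
```

If the initial state is only ever zeros, there is a simpler option. nn.LSTM already defaults h_0 and c_0 to zeros of the correct size when the hidden tuple is omitted, so calling self.lstm(embeds) sidesteps the problem.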