在TensorFlow 1.12中,如果LSTM输入的序列是变长的,可以用`tf.nn.dynamic_rnn()`
或`tf.nn.bidirectional_dynamic_rnn()`
方法来处理。那么在PyTorch中该怎么处理呢?
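PyTorch中也有对应的方法:`torch.nn.utils.rnn`
包中的`pack_padded_sequence()`
和`pad_packed_sequence()`
用来处理变长序列的问题。pack是压缩的意思,pad是填充的意思。PyTorch的做法是先根据每条序列的真实长度seq_len把输入压缩(去掉padding部分),经过LSTM后再把输出填充回统一长度。因此这两个方法需要知道的是每条序列的长度,而不是padding的index:padding的index是通过`nn.Embedding`的`padding_idx`参数指定的,`pad_packed_sequence()`则用`padding_value`指定填充值。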
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Toy batch of 4 token-id sequences, right-padded with id 0 to a common length of 3.
X = torch.tensor([[1, 2, 3],
                  [1, 2, 0],
                  [3, 0, 0],
                  [2, 1, 0]])
# True (unpadded) length of each row of X. pack_padded_sequence uses these
# lengths, not the padding id, to decide which time steps to drop.
seq_len = torch.tensor([3, 2, 1, 2])
print('X shape', X.shape)
vocab_size = 4
embedding_dim = 2
hidden_size = 6
num_layers = 1
num_directions = 2  # the LSTM below is bidirectional
# Derive the batch size from the data so h0 can never disagree with X.
batch_size = X.size(0)
# padding_idx=0: the embedding row for id 0 is initialised to zeros and gets no
# gradient, so padded positions embed to all-zero vectors.
word_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
word_vectors = word_embedding(X)
print(word_vectors)
# Each direction gets hidden_size // 2 units, so the concatenated bidirectional
# output has hidden_size features per time step.
lstm = LSTM(input_size=embedding_dim,
            hidden_size=hidden_size // 2,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True)
# enforce_sorted=False lets the batch come in any order: the packed sequence
# records sorted_indices / unsorted_indices so the original order can be restored.
packed_X = pack_padded_sequence(word_vectors, seq_len, batch_first=True, enforce_sorted=False)
print('--->pack padded sequence:{}'.format(packed_X))
# Initial (h0, c0), each shaped (num_layers * num_directions, batch, hidden_size // 2).
# batch_first=True affects only input/output, not the hidden-state layout.
h0 = (torch.randn(num_layers * num_directions, batch_size, hidden_size // 2),
      torch.randn(num_layers * num_directions, batch_size, hidden_size // 2))
outputs, h = lstm(packed_X, h0)
print('after lstm outputs: {}'.format(outputs))
# total_length=X.size(1) keeps the padded output as long as the padded input,
# even when the longest real sequence is shorter than X's time dimension.
pad_packed_X = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0,
                                   total_length=X.size(1))
# pad_packed_X is a tuple: (padded outputs, per-sequence lengths)
print('---->pad_packed:{}'.format(pad_packed_X))
# [batch_size, max_seq_len, hidden_size]
print('---->pad_outputs:{}'.format(pad_packed_X[0].shape))
print('---->seq_len:{}'.format(pad_packed_X[1].shape))
X shape torch.Size([4, 3])
tensor([[[ 0.7049, -0.6178],
[-2.0429, 0.7651],
[-0.4018, 0.5503]],
[[ 0.7049, -0.6178],
[-2.0429, 0.7651],
[ 0.0000, 0.0000]],
[[-0.4018, 0.5503],
[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[-2.0429, 0.7651],
[ 0.7049, -0.6178],
[ 0.0000, 0.0000]]], grad_fn=<EmbeddingBackward>)
--->pack padded sequence:PackedSequence(data=tensor([[ 0.7049, -0.6178],
[ 0.7049, -0.6178],
[-2.0429, 0.7651],
[-0.4018, 0.5503],
[-2.0429, 0.7651],
[-2.0429, 0.7651],
[ 0.7049, -0.6178],
[-0.4018, 0.5503]], grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
after lstm outputs: PackedSequence(data=tensor([[ 0.0594, -0.1153, 0.2028, 0.0240, -0.3177, 0.1364],
[ 0.1639, -0.0723, 0.2713, 0.0129, -0.3103, -0.0151],
[ 0.0144, 0.3130, 0.0122, -0.1323, -0.0634, 0.0550],
[ 0.2420, -0.3547, 0.1539, 0.0024, -0.4087, 0.3998],
[ 0.0393, 0.2744, 0.0270, -0.0392, -0.0846, 0.2874],
[ 0.0752, 0.3018, 0.0351, -0.0354, -0.2119, -0.5068],
[ 0.2033, 0.1408, 0.0661, -0.4325, -0.0149, 0.0146],
[ 0.0668, 0.3305, 0.0304, 0.2233, -0.0265, 0.1531]],
grad_fn=<CatBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
---->pad_packed:(tensor([[[ 0.0594, -0.1153, 0.2028, 0.0240, -0.3177, 0.1364],
[ 0.0393, 0.2744, 0.0270, -0.0392, -0.0846, 0.2874],
[ 0.0668, 0.3305, 0.0304, 0.2233, -0.0265, 0.1531]],
[[ 0.1639, -0.0723, 0.2713, 0.0129, -0.3103, -0.0151],
[ 0.0752, 0.3018, 0.0351, -0.0354, -0.2119, -0.5068],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.2420, -0.3547, 0.1539, 0.0024, -0.4087, 0.3998],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.0144, 0.3130, 0.0122, -0.1323, -0.0634, 0.0550],
[ 0.2033, 0.1408, 0.0661, -0.4325, -0.0149, 0.0146],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
grad_fn=<IndexSelectBackward>), tensor([3, 2, 1, 2]))
---->pad_outputs:torch.Size([4, 3, 6])
---->seq_len:torch.Size([4])
Process finished with exit code 0
参考 :
https://blog.csdn.net/So_that/article/details/94731614
https://gist.github.com/dolaameng/8918fefc05d0589f66279eab763d1e13