nn.utils.rnn.pack_padded_sequence、nn.utils.rnn.pad_packed_sequence 以及双向 RNN 的 output 与 hidden
参考：https://www.cnblogs.com/lindaxin/p/8052043.html
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import utils as nn_utils
import numpy as np
import logging
# Demo: pack a padded batch, run a bidirectional RNN on it, then unpack it.
# Sizes of the random input batch and of the RNN hidden state.
BATCH_SIZE, MAX_LEN, INPUT_SIZE = 3, 4, 5
HIDDEN_SIZE = 3

# Random padded input laid out batch-first: (batch, max_len, input_size).
a = torch.randn(BATCH_SIZE, MAX_LEN, INPUT_SIZE)
print(a)

# Number of valid (non-padding) steps per sequence, listed longest-first,
# which is the order pack_padded_sequence expects by default.
length = [4, 2, 1]
embedded = nn.utils.rnn.pack_padded_sequence(a, length, batch_first=True)
print(embedded)

rnn = nn.RNN(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=1,
    batch_first=True,
    bidirectional=True,
)
packed_output, hidden = rnn(embedded)
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

# output concatenates the forward and backward directions along the last
# axis, giving (batch, max_len, 2 * HIDDEN_SIZE); padded steps are zero-filled.
print(output)

# hidden keeps the two directions separate, giving (2, batch, HIDDEN_SIZE).
# Some of its values reappear in output: hidden[0] (forward) equals the
# first half of output at each sequence's last valid step, and hidden[1]
# (backward) equals the second half of output at step 0.
print(hidden)
输出（未固定随机种子，每次运行的数值都会不同）：
tensor([[[ 0.0502, -0.3737, 1.3803, 0.7827, -1.9780],
[-1.1792, 1.5573, -0.0832, 0.7857, -0.7991],
[-0.7319, -1.4783, -0.2471, -1.0115, 1.0919],
[ 0.4865, 0.1716, 0.0242, 0.2300, -1.1641]],
[[ 0.2066, -0.5143, -1.1541, -1.4135, 0.0579],
[-1.8993, 1.3314, 1.5115, 0.3241, -1.2116],
[-0.3505, 2.4541, 0.0762, -0.8498, -0.1097],
[ 1.7048, -1.1850, -1.7106, -0.1328, -0.1891]],
[[-0.1876, -1.4371, -1.8536, 1.4101, 0.5216],
[ 0.3570, 0.8627, -0.1122, -1.2451, -0.9119],
[ 2.3923, 1.0202, 1.5474, -0.7053, 0.5547],
[-0.7154, 1.3512, -1.8014, 0.3778, 0.9495]]])
PackedSequence(data=tensor([[ 0.0502, -0.3737, 1.3803, 0.7827, -1.9780],
[ 0.2066, -0.5143, -1.1541, -1.4135, 0.0579],
[-0.1876, -1.4371, -1.8536, 1.4101, 0.5216],
[-1.1792, 1.5573, -0.0832, 0.7857, -0.7991],
[-1.8993, 1.3314, 1.5115, 0.3241, -1.2116],
[-0.7319, -1.4783, -0.2471, -1.0115, 1.0919],
[ 0.4865, 0.1716, 0.0242, 0.2300, -1.1641]]), batch_sizes=tensor([3, 2, 1, 1]))
tensor([[[-0.6479, -0.3464, -0.4421, 0.6213, -0.9616, 0.0461],
[ 0.6239, 0.0384, -0.4875, -0.5125, 0.6439, 0.7058],
[ 0.6740, 0.9082, 0.8267, 0.6974, -0.0902, -0.7959],
[ 0.7346, -0.6262, -0.1682, 0.7487, -0.5633, 0.6784]],
[[ 0.7597, -0.0259, 0.3454, 0.8721, -0.0162, -0.5046],
[ 0.4286, 0.3801, -0.6889, -0.8130, -0.3739, -0.4051],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.5331, 0.8544, 0.9515, 0.9457, 0.1661, 0.3958],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
grad_fn=)
tensor([[[ 0.7346, -0.6262, -0.1682],
[ 0.4286, 0.3801, -0.6889],
[ 0.5331, 0.8544, 0.9515]],
[[ 0.6213, -0.9616, 0.0461],
[ 0.8721, -0.0162, -0.5046],
[ 0.9457, 0.1661, 0.3958]]], grad_fn=)