torch.nn.utils.rnn.PackedSequence
Instances of this class should not be created by hand; they are meant to be instantiated by pack_padded_sequence().
torch.nn.utils.rnn.pack_padded_sequence()
Inputs:
input: [seq_length x batch_size x input_size] or [batch_size x seq_length x input_size]; the sequences in input must be arranged in order of decreasing length.
lengths: a list of the sequence lengths, sorted in decreasing order, matching the sequences in input, e.g. [5, 4, 1].
batch_first: a bool; when True, input takes the form [batch_size x seq_length x input_size], otherwise the other form.
Output:
A PackedSequence object, containing a Variable-typed data and a list-typed batch_sizes.
Each element of batch_sizes indicates how many rows of data make up one batch (i.e., one time step across the sequences).
For example, the input is:
input
Variable containing:
(0 ,.,.) =
1
2
3
(1 ,.,.) =
1
0
0
[torch.FloatTensor of size 2x3x1]
lengths = [3, 1]
To achieve the packed encoding, i.e., to strip out the padding, the final output is:
PackedSequence(data=Variable containing:
1
1
2
3
[torch.FloatTensor of size 4x1]
, batch_sizes=[2, 1, 1])
This shows that the first two 1s belong to the same batch (time step), while the remaining two entries each belong to their own. In other words, from batch_sizes you can read off that the two sequences have lengths 3 and 1. Downstream modules or functions can consume the corresponding rows of data according to batch_sizes.
Using the input above as our example, let's examine how this function actually compresses the data.
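The example above can be reproduced directly with the current PyTorch API (note that in modern PyTorch, Variable has been merged into Tensor, and batch_sizes is returned as a tensor rather than a plain list):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# The document's example: two sequences of lengths 3 and 1,
# padded with zeros, in batch-first layout [batch=2, seq=3, feat=1].
x = torch.tensor([[[1.], [2.], [3.]],
                  [[1.], [0.], [0.]]])
lengths = [3, 1]

packed = pack_padded_sequence(x, lengths, batch_first=True)
print(packed.data.squeeze(1).tolist())  # [1.0, 1.0, 2.0, 3.0]
print(packed.batch_sizes.tolist())      # [2, 1, 1]
```

The packed data interleaves the sequences step by step: first the step-1 elements of both sequences, then the remaining elements of the longer one.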
def pack_padded_sequence(input, lengths, batch_first=False):
    # check that every length is > 0
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <= 0")
    # bring input into the shape [seq_length x batch_size x input_size]
    # (for our example, the shape is [3, 2, 1] after the transpose)
    if batch_first:
        input = input.transpose(0, 1)

    steps = []
    batch_sizes = []
    # iterate over the lengths from shortest to longest
    lengths_iter = reversed(lengths)
    # here current_length == 1
    current_length = next(lengths_iter)
    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")

    # the second argument 1 makes 'step' start counting from 1
    for step, step_value in enumerate(input, 1):
        """
        at step 1, step_value ==
         1
         1
        [torch.FloatTensor of size 2x1]
        """
        steps.append(step_value[:batch_size])
        batch_sizes.append(batch_size)

        # check whether this step reaches the end of a shorter sequence
        while step == current_length:
            try:
                new_length = next(lengths_iter)
            except StopIteration:
                current_length = None
                break

            # verify that lengths is a decreasing list
            if current_length > new_length:  # new_length is the preceding length in the array
                raise ValueError("lengths array has to be sorted in decreasing order")
            # a shorter sequence has just ended, so the batch size shrinks by 1
            batch_size -= 1
            current_length = new_length

        if current_length is None:
            break

    # concatenate the per-step slices along dim 0
    return PackedSequence(torch.cat(steps), batch_sizes)
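The batch_sizes bookkeeping above does not depend on tensors at all, so the loop can be sketched in pure Python to check our understanding. The helper name batch_sizes_from_lengths below is our own, not part of PyTorch:

```python
def batch_sizes_from_lengths(lengths):
    """Pure-Python sketch of the loop above: derive batch_sizes from a
    decreasing list of sequence lengths (no tensors involved)."""
    batch_sizes = []
    batch_size = len(lengths)                # number of sequences still active
    remaining = list(reversed(lengths))      # shortest length first
    current_length = remaining.pop(0)
    for step in range(1, lengths[0] + 1):
        batch_sizes.append(batch_size)
        # a sequence (or several of equal length) ends at this step
        while step == current_length:
            if not remaining:
                current_length = None
                break
            new_length = remaining.pop(0)
            batch_size -= 1                  # one fewer active sequence
            current_length = new_length
        if current_length is None:
            break
    return batch_sizes

print(batch_sizes_from_lengths([3, 1]))     # [2, 1, 1]
print(batch_sizes_from_lengths([5, 4, 1]))  # [3, 2, 2, 2, 1]
```

This matches the batch_sizes=[2, 1, 1] from the worked example, and the [5, 4, 1] case from the parameter description.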
nn.utils.rnn.pad_packed_sequence()
This is the inverse of the previous function. Its input is a PackedSequence object; using the batch_sizes it contains, the data can be unpacked back into a padded tensor.
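A quick round trip with the current PyTorch API illustrates the inverse relationship (pad_packed_sequence also returns the recovered lengths alongside the padded tensor):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Same example as before: two sequences of lengths 3 and 1, zero-padded.
x = torch.tensor([[[1.], [2.], [3.]],
                  [[1.], [0.], [0.]]])
packed = pack_padded_sequence(x, [3, 1], batch_first=True)

# pad_packed_sequence restores the padded tensor plus the lengths.
unpacked, lens = pad_packed_sequence(packed, batch_first=True)
print(unpacked.squeeze(-1).tolist())  # [[1.0, 2.0, 3.0], [1.0, 0.0, 0.0]]
print(lens.tolist())                  # [3, 1]
```

Positions beyond each sequence's length are filled with the padding value (0 by default), so the original padded input is recovered exactly here.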