torch.nn.utils.rnn下面pack_padded_sequence和pad_packed_sequence方法

1. pack_padded_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch
seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
lens = [2, 1, 3]
packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
packed
Out: PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]), sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))

这个函数主要做了两件事: pad 和封装,因为在rnn模型中,一般先将batch中的数据按照一个时间步一个时间步喂入模型的,这个包的主要作用就是将按照样本堆叠的数据,抽取出时间步这个维度重新堆叠。

input: pad_sequence 的结果
length:batch 中各个句子的实际长度
batch_first: batch 是否在第一维,上面的例子指定为了True,意思是我们的输入数据第一维是batch。 一般放入lstm或者gru是需要时间步放在第一维的。
enforce_sorted:如果是 True ,则输入应该是按长度降序排序的序列。如果是 False ,会在函数内部进行排序。默认值为 True。
需要注意的是,默认条件下,我们必须把输入数据按照序列长度从大到小排列后才能送入 pack_padded_sequence ,否则会报错。

packed = pack_padded_sequence(seq, lens)
Traceback (most recent call last):
  File "/Users/daxu/.conda/envs/py38/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in <module>
    packed = pack_padded_sequence(seq, lens)
  File "/Users/daxu/.conda/envs/py38/lib/python3.8/site-packages/torch/nn/utils/rnn.py", line 245, in pack_padded_sequence
    _VF._pack_padded_sequence(input, lengths, batch_first)
RuntimeError: `lengths` array must be sorted in decreasing order when `enforce_sorted` is True. You can pass `enforce_sorted=False` to pack_padded_sequence and/or pack_sequence to sidestep this requirement if you do not need ONNX exportability.
2. pad_packed_sequence
seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
seq_unpacked
Out[55]: 
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
lens_unpacked
Out[56]: tensor([2, 1, 3])

pad_packed_sequence方法是pack_padded_sequence的逆操作

input:填充的可变长度序列批次。
batch_first:如果为 True ,输出将采用 B x T x * 格式。
padding_value:填充元素的值。
total_length:如果不是 None ,输出将被填充为长度为 total_length 。如果 total_length 小于 sequence 中的最大序列长度, sequence 方法将抛出 ValueError 。

你可能感兴趣的:(pytorch,rnn,深度学习,人工智能)