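pad_sequence pads a list of variable-length tensors with padding_value, extending the shorter sequences to the length of the longest one.
In short: it pads sentences to the same length.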
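Parameters:
batch_first: if True, the output has shape B × T × *; otherwise it has shape T × B × *. The default is False.
Here B is the batch size, T is the length of each sequence after padding, and * stands for any number of trailing dimensions.
Output:
If batch_first is False, the tensor has shape T × B × *; otherwise it has shape B × T × *.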
For example:
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
# Three token-id sequences of lengths 7, 5 and 4
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
# Pad with zeros to a 3 x 7 tensor (batch dimension first)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
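As a complementary sketch, the snippet below compares the two layouts and uses a non-default padding_value. The value -1 and the variable names are chosen only for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Same three sequences as in the example above (lengths 7, 5 and 4).
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
data = [torch.tensor(x) for x in content]

# Default layout (batch_first=False): T x B, here 7 x 3.
time_major = pad_sequence(data)
# batch_first=True: B x T, here 3 x 7.
batch_major = pad_sequence(data, batch_first=True)
# padding_value=-1 is an illustrative choice; the default is 0.
padded_neg = pad_sequence(data, batch_first=True, padding_value=-1)

print(time_major.shape)   # torch.Size([7, 3])
print(batch_major.shape)  # torch.Size([3, 7])
print(padded_neg[2])      # 4 real tokens followed by three -1 entries
```

The two layouts hold the same values; one is simply the transpose of the other.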
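pack_padded_sequence packs a tensor containing padded sequences of variable length. Padding with pad_sequence introduces redundant entries, so the padded tensor needs to be packed.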
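Parameters:
batch_first: if True, the input is expected in shape B × T × *. I usually set it to True.
enforce_sorted: if True, lengths must be sorted in decreasing order, so the sequences in input must be sorted the same way. If False, the input is sorted unconditionally. I usually set it to False.
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence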
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)
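To make the effect of packing more visible, here is a minimal sketch that inspects the fields of the returned PackedSequence. The sequences are deliberately not sorted by length, and the names seqs, lengths, padded and packed are illustrative only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Three sequences that are NOT sorted by length (lengths 2, 3, 1).
seqs = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]
lengths = [len(s) for s in seqs]

padded = pad_sequence(seqs, batch_first=True)
# enforce_sorted=False lets PyTorch sort the batch by length internally.
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

# data holds only the real tokens, one time step after another,
# with the longest sequence first inside each step.
print(packed.data)            # tensor([3, 1, 6, 4, 2, 5])
# batch_sizes[t] is the number of sequences still active at step t.
print(packed.batch_sizes)     # tensor([3, 2, 1])
# sorted_indices records the order in which the batch was rearranged.
print(packed.sorted_indices)  # tensor([1, 0, 2])
```

The padded tensor has 9 entries, while data keeps only the 6 real tokens, which is the redundancy that packing removes.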
pad_packed_sequence pads the packed result to restore its original padded shape.
Parameters:
batch_first: if True, the output has shape B × T × *.
Output:
A tuple containing the tensor of padded sequences and a tensor holding the length of each sequence in the batch.
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)
p3 = pad_packed_sequence(p2, batch_first=True)
print(p3)
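The round trip can also be checked explicitly. The sketch below unpacks the returned tuple and confirms that the original padded tensor and lengths come back, even for a batch that was not sorted by length. The names are illustrative, and total_length=5 is an arbitrary value picked for the demonstration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

seqs = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]
lengths = [len(s) for s in seqs]
padded = pad_sequence(seqs, batch_first=True)
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

# The return value is a tuple: (padded tensor, lengths tensor).
restored, restored_lengths = pad_packed_sequence(packed, batch_first=True)
print(torch.equal(restored, padded))  # True
print(restored_lengths)               # tensor([2, 3, 1])

# total_length pads beyond the longest sequence in the batch.
longer, _ = pad_packed_sequence(packed, batch_first=True, total_length=5)
print(longer.shape)                   # torch.Size([3, 5])
```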
pack_sequence packs a list of variable-length tensors directly.
Parameters:
sequences (list[Tensor]): a list of sequences of decreasing length.
enforce_sorted (bool, optional): if True, checks that the input contains sequences sorted by length in a decreasing order. If False, this condition is not checked. Default: True.
from torch.nn.utils.rnn import pack_sequence
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
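# The second argument is enforce_sorted; a, b, c are already in decreasing length order
print(pack_sequence([a, b, c], enforce_sorted=True))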
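As a rough sketch of how pack_sequence relates to the other functions, the snippet below compares it with calling pad_sequence followed by pack_padded_sequence on the same input. The variable names direct and two_step are illustrative.

```python
import torch
from torch.nn.utils.rnn import pack_sequence, pad_sequence, pack_padded_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
seqs = [a, b, c]  # already sorted by decreasing length

# One call: pack the list directly.
direct = pack_sequence(seqs, enforce_sorted=True)

# Two calls: pad first (time-major by default), then pack with explicit lengths.
two_step = pack_padded_sequence(pad_sequence(seqs), [3, 2, 1], enforce_sorted=True)

print(direct.data)         # tensor([1, 4, 6, 2, 5, 3])
print(direct.batch_sizes)  # tensor([3, 2, 1])
print(torch.equal(direct.data, two_step.data))  # True
```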