当一个Dataset处理完数据后,需要加载时,希望在一个mini batch内pad数据,把数据pad成这个批内最大的长度,减小不必要的显存消耗。
torch给提供了这样的函数,在torch.nn.utils.rnn.pad_sequence
函数。
# 函数返回一个T x B x * 或 B x T x *的一个tensor,当batch_first=True时,B在前面。
# 需要pad的tensor的维度放在第一个维度上。
"""
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
where `T` is the length of the longest sequence. This function assumes
trailing dimensions and type of all the Tensors in sequences are same.
"""
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
一维tensor的pad:
>>> a
tensor([1., 1., 1.])
>>> b
tensor([1., 1., 1., 1., 1., 1.])
>>> c
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
>>> pad_sequence([a,b,c],batch_first=True).size()
torch.Size([3, 8])
>>> pad_sequence([a,b,c],batch_first=True)
tensor([[1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 1., 1., 1.]])