在每个批内pad一个tensor

当一个Dataset处理完数据后,需要加载时,希望在一个mini batch内pad数据,把数据pad成这个批内最大的长度,减小不必要的显存消耗。

torch给提供了这样的函数,在torch.nn.utils.rnn.pad_sequence函数。

# 函数返回一个T x B x * 或 B x T x *的一个tensor,当batch_first=True时,B在前面。
# 需要pad的tensor的维度放在第一个维度上。
"""
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.
"""
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])

一维tensor的pad:

>>> a
tensor([1., 1., 1.])
>>> b
tensor([1., 1., 1., 1., 1., 1.])
>>> c
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
>>> pad_sequence([a,b,c],batch_first=True).size()
torch.Size([3, 8])
>>> pad_sequence([a,b,c],batch_first=True)
tensor([[1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

你可能感兴趣的:(pytorch)