Using pad_sequence in PyTorch

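>>> import numpy as np
>>> import torch
>>> from torch.nn.utils.rnn import pad_sequence
>>> input_x = [[1, 2, 3], [4, 5, 6, 7, 8], [8, 9]]
>>> # Zero-pad every sequence to the longest length; batch_first=True gives shape (batch, max_len)
>>> norm_data_pad = pad_sequence([torch.from_numpy(np.array(x)) for x in input_x], batch_first=True).float()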
>>> norm_data_pad
tensor([[1., 2., 3., 0., 0.],
        [4., 5., 6., 7., 8.],
        [8., 9., 0., 0., 0.]])
>>> norm_data_pad.dtype
torch.float32
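
To make the padding rule in the session above explicit, here is a minimal pure-Python sketch of what the call does to these three lists: each sequence is right-padded with a fill value up to the length of the longest one. The helper name `pad_lists` and the `lengths` list are illustrative assumptions, not part of PyTorch; the real `pad_sequence` works on tensors and takes a `padding_value` argument for the fill value.

```python
def pad_lists(sequences, padding_value=0.0):
    """Hypothetical helper: right-pad each list to the longest length (batch-first layout)."""
    max_len = max(len(seq) for seq in sequences)
    return [
        [float(v) for v in seq] + [float(padding_value)] * (max_len - len(seq))
        for seq in sequences
    ]


input_x = [[1, 2, 3], [4, 5, 6, 7, 8], [8, 9]]
padded = pad_lists(input_x)

# Original lengths, recorded before padding so the padded positions can be masked later.
lengths = [len(seq) for seq in input_x]
```

Keeping the original lengths next to the padded batch is a common choice, because the padded zeros cannot otherwise be told apart from real zero values in the data.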
