【pytorch框架学习】nn.Embedding中的padding_idx用法示意

import torch
import torch.nn as nn
# Baseline: a lookup table with 10 rows of 3 values each and no padding index.
# Every row of the weight printed below holds ordinary non-zero initial values.
embedding1 = nn.Embedding(num_embeddings=10, embedding_dim=3)
embedding1.weight
Parameter containing:
tensor([[-0.9116,  0.5195, -1.3509],
        [ 0.5670,  0.8024, -0.0373],
        [-0.8223, -1.2181, -0.6713],
        [-1.2734, -1.0591, -1.1202],
        [-0.4734,  1.8297,  0.3880],
        [ 0.5687,  0.3136,  0.7541],
        [ 1.0070, -0.0197, -0.1715],
        [ 2.1003,  0.6229,  0.6720],
        [-0.1729, -0.6555,  0.2904],
        [-1.6015, -1.3011, -0.5837]], requires_grad=True)
# Same 10x3 table, but index 0 is declared as the padding index.
# In the weight printed below, row 0 is all zeros.
embedding2 = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)
embedding2.weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.5784, -1.5044, -1.7400],
        [-1.1197,  0.8234, -0.6458],
        [ 0.8204,  2.0259, -0.9619],
        [ 0.1317, -0.3696, -1.6996],
        [-0.2763, -0.3568,  0.2973],
        [-1.2864, -0.2396,  1.3876],
        [-1.6487, -0.0096,  0.1984],
        [-0.2213, -1.0257, -0.6359],
        [ 0.2354, -0.7799, -0.3288]], requires_grad=True)
# The padding index can be any row: here it is index 2.
# In the weight printed below, row 2 (and only row 2) is all zeros.
embedding3 = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=2)
embedding3.weight
Parameter containing:
tensor([[-1.0108, -1.5298, -0.3603],
        [-1.1312,  1.4528, -0.7718],
        [ 0.0000,  0.0000,  0.0000],
        [-0.8255, -0.4083,  0.7372],
        [-0.8608,  0.2809,  0.1835],
        [-0.6224, -0.1390, -0.7797],
        [-0.6382,  0.6341,  0.2778],
        [-0.6328,  0.2855, -0.3784],
        [-0.8825, -0.2000, -1.2097],
        [ 0.9235,  0.5388,  0.8158]], requires_grad=True)
# Use the last valid index (num_embeddings - 1, i.e. 9) as the padding index.
# In the weight printed below, the final row is all zeros.
embedding4 = nn.Embedding(
    num_embeddings=10,
    embedding_dim=3,
    padding_idx=10 - 1,
)
embedding4.weight
Parameter containing:
tensor([[-1.4354,  0.8168,  0.4477],
        [-0.4925, -0.3006,  0.7584],
        [-0.2400,  1.0259, -0.5391],
        [-0.5411,  0.9602,  0.1372],
        [-0.6848,  0.0278, -0.1112],
        [-0.2092, -1.8230,  1.0283],
        [ 0.5441,  0.6374, -0.9901],
        [ 0.1115,  0.2792, -0.1808],
        [-3.7124, -0.6969, -0.6027],
        [ 0.0000,  0.0000,  0.0000]], requires_grad=True)
# Look up a 2x4 batch of indices in embedding4 (padding index 9).
# In the output below, every position holding 9 comes back as an all-zero
# vector, while indices 1, 3 and 6 return their rows from the weight table.
# NOTE: the name `input` shadows the Python builtin; it is kept unchanged here.
input = torch.tensor(
    [
        [1, 9, 9, 9],
        [9, 3, 9, 6],
    ]
)
embedding4(input)
tensor([[[-0.4925, -0.3006,  0.7584],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.5411,  0.9602,  0.1372],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.5441,  0.6374, -0.9901]]], grad_fn=<EmbeddingBackward>)

你可能感兴趣的:(pytorch)