Understanding PyTorch's nn.Linear()


torch.nn.Linear(in_features, out_features, bias=True)

in_features is the size of the second dimension of the input tensor. The layer takes a [N, Hin] tensor as input, where N is usually the batch_size (number of samples) and Hin is the feature dimension of each sample.

out_features is the size of the second dimension of the output tensor. The output is [N, Hout], where N is again the batch_size and Hout is the output feature dimension of each sample; it also equals the number of neurons in the fully connected layer.

In other words, an input tensor of shape [batch_size, in_features] is transformed into an output tensor of shape [batch_size, out_features].
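Under the hood, nn.Linear stores a learnable weight of shape [out_features, in_features] and, with bias=True (the default), a bias of shape [out_features]; the forward pass computes y = x @ weight.t() + bias. A minimal sketch of this equivalence (the names fc_demo and x_demo are illustrative, not from the original example):

import torch
import torch.nn as nn

fc_demo = nn.Linear(10, 15)    # weight: [15, 10], bias: [15]
x_demo = torch.randn(4, 10)    # 4 samples, 10 features each

# nn.Linear applies the affine map y = x @ W^T + b
manual = x_demo @ fc_demo.weight.t() + fc_demo.bias
print(torch.allclose(fc_demo(x_demo), manual))   # expected: True
print(fc_demo.weight.shape, fc_demo.bias.shape)  # torch.Size([15, 10]) torch.Size([15])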

Define a fully connected layer with input dimension 10 and output dimension 15, and test its input and output.


>>> import torch
>>> import torch.nn as nn
>>> fc=nn.Linear(10,15)
Define input_text as the input to fc: 20 samples, each with a feature dimension of 10.

>>> input_text=torch.randn(20,10)
>>> input_text
tensor([[ 0.5634,  1.0797, -0.2739, -0.5231, -0.3393,  1.1104, -0.2017, -0.0966,
         -1.4458, -1.7758],
        [-1.4232,  0.6796,  0.1717, -0.4230, -0.2710,  1.2147,  1.7306,  0.0239,
          0.7102,  0.1282],
        [ 1.0540,  0.4645,  1.3008, -0.2010,  1.6890,  0.3166, -0.1252, -0.0690,
          0.3482,  1.2433],
        [ 1.4939, -1.0104, -0.1116, -0.0502,  0.7267,  0.3448, -0.4479, -0.4113,
         -0.3029, -0.6073],
        [-0.5994, -1.5742, -2.0787,  1.1467, -1.2804,  0.6828, -0.9692, -1.6204,
         -1.0283, -0.6964],
        [ 0.2336, -0.2322,  1.2112, -0.4068, -0.6032,  0.2960, -0.2923, -0.1780,
          0.3501, -0.8640],
        [ 1.3801,  0.7794, -0.6012,  2.0979,  0.8809, -0.7856, -0.0502,  0.9142,
          0.1903,  2.3709],
        [-0.3090, -1.3813, -0.1347, -0.9364, -0.7701, -0.6507,  1.5855,  1.2936,
         -1.2553,  0.6814],
        [ 0.0691, -0.6342, -0.3846, -1.2839,  0.9339,  0.4181, -0.1047,  0.9791,
         -1.0878,  2.1031],
        [ 0.4759, -0.4706, -0.5551,  0.7296,  0.0382,  0.9609, -1.6298, -0.1432,
         -0.5032, -0.3488],
        [ 0.0286,  0.9674,  0.0237, -0.9054,  0.1151,  1.0193,  1.1148,  0.1982,
         -0.6999,  0.3033],
        [-0.1647,  0.3057, -0.0309,  1.9282,  1.7134,  1.1035, -0.8432, -0.5180,
         -0.1395,  0.0347],
        [-1.5691,  1.1859, -0.9767, -1.3447,  0.6255, -0.4162,  0.5525,  1.4220,
         -2.5235,  0.5372],
        [-0.5764, -0.2917, -2.1132,  1.9557, -1.4700,  1.6644,  1.4034,  0.3968,
         -0.8569,  0.2588],
        [-0.7345,  0.9567, -0.6932, -0.1686, -0.2532, -0.7115,  0.4753, -0.9203,
          1.1279, -1.2294],
        [-0.7426, -0.1470, -0.3995, -0.4880,  0.3165,  1.3640,  0.4278,  1.0734,
         -1.9965, -1.7040],
        [ 0.3277, -1.4838, -0.6945, -1.0994, -0.4837, -1.9070,  0.1701, -0.2308,
          1.0275,  1.3653],
        [ 2.0126,  0.5750, -0.3649, -0.2077, -0.0483,  0.4669, -0.5713,  2.7418,
          0.2346,  0.0564],
        [ 2.1609,  0.4623, -1.4637, -1.6748,  0.7956, -0.2733, -0.3300, -0.1264,
         -1.6609, -0.5499],
        [ 0.9671, -0.7844, -1.7737, -1.7201, -0.6303, -0.2596,  0.5687,  1.5687,
          0.3326,  0.2705]])
   
Feed input_text into fc:

>>> result_output=fc(input_text)
Print the result:

>>> result_output
tensor([[ 1.8107e-01, -9.2258e-02,  2.4591e-01, -1.9315e-01, -3.3488e-01,
          1.2616e-01,  4.5071e-01, -2.7091e-01,  1.3229e-01,  8.5799e-01,
         -3.0633e-01,  1.4934e-01,  1.5582e+00,  2.7456e-01, -2.0479e-01],
        [-4.4143e-01, -1.2819e+00, -9.4922e-01, -8.4364e-01, -4.7017e-01,
         -3.4534e-03,  6.7187e-01,  2.8355e-01,  4.0641e-01,  2.6964e-01,
         -6.9916e-01, -8.4993e-01,  1.3252e-01, -4.3188e-01,  4.7683e-01],
        [-1.2963e+00, -7.0305e-01,  3.1611e-01,  2.5703e-01, -5.6557e-01,
          3.3875e-01, -7.9337e-01,  9.0740e-01,  8.6513e-01,  1.9334e-01,
         -2.6332e-01,  2.7622e-01, -3.4500e-01,  4.9011e-01, -8.4280e-01],
        [-1.1210e-02,  2.5886e-01,  8.8588e-01,  7.9358e-01,  5.5724e-02,
         -1.2594e-01,  2.1723e-01, -4.5786e-02, -2.4303e-01,  2.9119e-01,
          4.6201e-01, -3.2916e-01,  5.0725e-01,  6.8928e-01, -3.3190e-01],
        [ 1.2481e+00,  1.2070e+00,  3.3738e-01, -2.5593e-01,  1.3632e+00,
         -3.5514e-01,  1.7132e+00, -9.8362e-01, -2.0212e-01, -6.1235e-02,
          9.5519e-01, -9.1147e-01,  6.6596e-01,  4.5335e-01,  4.4601e-01],
        [-5.1957e-01, -1.0777e-01,  7.3596e-01,  8.8377e-02, -3.4743e-01,
         -1.5220e-01,  3.0113e-01, -3.7542e-01, -2.8111e-01,  5.4499e-03,
          3.9416e-03, -1.1461e-01,  3.5127e-01, -1.2291e-01,  1.4345e-01],
        [-1.8349e-01, -5.6514e-01,  1.0635e-01, -6.4179e-01, -9.6386e-02,
         -1.2538e-01, -9.9904e-01,  1.1812e+00,  5.4885e-01,  1.1835e+00,
         -2.2399e-01,  1.1104e+00, -4.6530e-01,  5.3867e-01, -1.5022e+00],
        [ 3.7626e-01,  1.5282e-01,  1.2204e-01, -6.1283e-01, -3.0752e-01,
          5.5161e-02,  4.2992e-01,  3.8260e-01, -6.8061e-01,  2.0682e-01,
          6.3711e-01, -1.2116e-01,  8.7099e-01,  6.0735e-01,  8.2737e-01],
        [-8.1442e-01, -1.9702e-02, -2.0397e-01, -3.0586e-01, -3.8486e-01,
          6.4765e-01, -3.4066e-01,  6.8476e-01,  3.9929e-01, -5.8264e-02,
          6.7580e-01, -4.1936e-02,  4.5742e-01,  1.1757e+00,  2.8947e-01],
        [-2.6684e-02,  4.0578e-01,  7.5473e-01,  1.4927e-01,  4.0935e-01,
         -1.5429e-01,  2.7960e-01, -3.9922e-01,  1.3923e-01,  4.1013e-01,
          2.2092e-01, -1.9759e-02,  3.7310e-01,  6.4977e-01, -1.7508e-01],
        [-5.1254e-01, -8.0600e-01, -6.7048e-01, -5.7318e-01, -7.2053e-01,
          4.1362e-01,  7.7607e-02,  5.0234e-01,  5.1123e-01,  5.5079e-01,
         -2.2140e-01, -3.5574e-02,  9.4809e-01,  1.5752e-01, -4.7720e-02],
        [ 7.3687e-03, -3.5115e-01,  6.1108e-01,  4.9131e-02,  7.5905e-01,
         -2.5723e-01,  1.0520e-01,  5.2249e-01,  1.2880e+00,  7.2025e-01,
         -1.0097e+00, -1.5668e-01, -7.5633e-01,  6.8615e-01, -8.1866e-01],
        [ 4.0765e-01, -1.9088e-02, -2.8948e-01, -1.4690e+00, -1.6058e-01,
          5.9143e-01,  3.3860e-01,  7.5121e-01,  5.8912e-01,  4.1719e-01,
         -5.3752e-01,  5.2774e-02,  9.3624e-01,  1.1130e+00,  3.4052e-01],
        [ 1.0823e+00, -4.6622e-01, -9.0878e-01, -1.0677e+00,  4.4945e-01,
         -4.2928e-01,  9.1671e-01,  7.0039e-02,  1.3152e-01,  1.4396e+00,
         -3.2818e-02, -1.5180e-01,  8.9904e-01,  4.7343e-02,  2.4571e-01],
        [ 4.2488e-01, -4.1203e-01, -1.5616e-02, -5.7839e-01,  2.7431e-02,
         -3.7477e-01,  1.1616e+00, -3.6302e-01, -3.9705e-01, -5.5184e-02,
         -5.1441e-01, -1.0853e+00,  5.5814e-02, -4.7424e-01, -1.6333e-01],
        [ 6.4548e-01, -9.6785e-02,  4.4011e-01, -2.2773e-01,  7.7957e-02,
         -5.9443e-02,  7.4260e-01,  4.5195e-02,  1.5288e-01,  9.1478e-01,
         -6.7894e-01, -3.9288e-01,  1.1380e+00,  8.8407e-01,  6.3056e-01],
        [-1.4826e-01,  4.7014e-01,  2.3744e-01, -1.8187e-01, -1.5348e-01,
         -4.0711e-02,  5.3915e-01, -1.7642e-01, -1.1896e+00, -8.9747e-01,
          1.4008e+00, -8.2545e-01, -4.0002e-03,  2.1347e-01,  3.0044e-01],
        [-6.7112e-01, -7.4937e-01,  4.4586e-01,  2.8668e-02, -1.2300e+00,
         -1.6951e-01, -1.0214e+00,  1.4511e-01, -8.5860e-01,  1.4472e+00,
         -4.0698e-02,  8.9120e-01,  1.0358e+00,  4.7799e-01, -4.3145e-01],
        [ 1.5972e-01,  4.2159e-01,  2.3895e-01,  3.2234e-01, -5.4578e-01,
          5.0599e-01,  6.8107e-02,  9.7544e-02, -3.9385e-01,  4.5454e-01,
          8.3391e-01, -1.6128e-03,  1.8425e+00,  1.0067e+00, -6.3919e-01],
        [-4.8341e-02, -1.9460e-01, -2.8181e-01, -1.1646e-01, -8.9199e-01,
         -2.8274e-02,  2.9038e-01, -2.7133e-01, -1.5819e+00,  2.9762e-01,
          1.0687e+00, -6.4864e-01,  1.3276e+00,  4.1109e-01,  5.7142e-01]],
       grad_fn=<AddmmBackward>)
>>> result_output.size()
torch.Size([20, 15])
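The output shape is [20, 15], as expected. The layer's parameters can also be inspected directly; continuing the same session, only the shapes are shown here since the values are randomly initialized:

>>> fc.weight.shape
torch.Size([15, 10])
>>> fc.bias.shape
torch.Size([15])
>>> fc.in_features, fc.out_features
(10, 15)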
If the last dimension of the input does not match in_features, the forward pass fails. Here a [2, 15] tensor is passed to fc, which expects inputs with 10 features:

>>> input_text3=torch.randn(2,15)
>>> result_output3=fc(input_text3)
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/functional.py", line 1370, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [2 x 15], m2: [10 x 15] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136
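The error message can be read against the parameter shapes: m1 is the input ([2 x 15]) and m2 is the transposed weight ([10 x 15]), and their inner dimensions (15 vs. 10) do not match. One way to catch this earlier is to compare the input's last dimension with in_features before calling the layer; the check_linear_input helper below is an illustrative sketch (it assumes the fc defined above), not part of PyTorch:

def check_linear_input(layer, x):
    # raise a clearer error if the feature dimension does not match the layer
    if x.shape[-1] != layer.in_features:
        raise ValueError(
            f"expected last dimension {layer.in_features}, got {x.shape[-1]}")
    return layer(x)

check_linear_input(fc, torch.randn(2, 10))   # OK: last dimension is 10
# check_linear_input(fc, torch.randn(2, 15)) # would raise the ValueError above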
