torch.nn.Linear(in_features, out_features, bias=True)
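in_features is the size of each input sample. The layer takes a 2-D tensor of shape [N, Hin] as input, where N is usually the batch_size (number of samples) and Hin is the dimension of each sample.
out_features is the size of each output sample. The output has shape [N, Hout], where N is again the batch_size and Hout is the dimension of each output sample, which is also the number of neurons in this fully connected layer.
In other words, the layer transforms an input tensor of shape [batch_size, in_features] into an output tensor of shape [batch_size, out_features].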
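The shape transformation can be checked against the underlying formula, y = x W^T + b. The following is a minimal sketch, assuming the same sizes used later in this note (10 input features, 15 output features, 20 samples). The names x, y and manual are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

fc = nn.Linear(in_features=10, out_features=15)
x = torch.randn(20, 10)  # [batch_size, in_features]

# Forward pass through the layer
y = fc(x)

# The same computation written out by hand: y = x @ W^T + b
# fc.weight has shape [out_features, in_features], hence the transpose.
manual = x @ fc.weight.t() + fc.bias

print(y.shape)
print(torch.allclose(y, manual, atol=1e-5))
```

The transpose is needed because nn.Linear stores its weight as [out_features, in_features].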
Define a fully connected layer with input dimension 10 and output dimension 15, then test its input and output.
>>> import torch
>>> from torch import nn
>>> fc = nn.Linear(10, 15)
Define input_text as the input to fc: 20 samples, each of dimension 10.
>>> input_text=torch.randn(20,10)
>>> input_text
tensor([[ 0.5634, 1.0797, -0.2739, -0.5231, -0.3393, 1.1104, -0.2017, -0.0966,
-1.4458, -1.7758],
[-1.4232, 0.6796, 0.1717, -0.4230, -0.2710, 1.2147, 1.7306, 0.0239,
0.7102, 0.1282],
[ 1.0540, 0.4645, 1.3008, -0.2010, 1.6890, 0.3166, -0.1252, -0.0690,
0.3482, 1.2433],
[ 1.4939, -1.0104, -0.1116, -0.0502, 0.7267, 0.3448, -0.4479, -0.4113,
-0.3029, -0.6073],
[-0.5994, -1.5742, -2.0787, 1.1467, -1.2804, 0.6828, -0.9692, -1.6204,
-1.0283, -0.6964],
[ 0.2336, -0.2322, 1.2112, -0.4068, -0.6032, 0.2960, -0.2923, -0.1780,
0.3501, -0.8640],
[ 1.3801, 0.7794, -0.6012, 2.0979, 0.8809, -0.7856, -0.0502, 0.9142,
0.1903, 2.3709],
[-0.3090, -1.3813, -0.1347, -0.9364, -0.7701, -0.6507, 1.5855, 1.2936,
-1.2553, 0.6814],
[ 0.0691, -0.6342, -0.3846, -1.2839, 0.9339, 0.4181, -0.1047, 0.9791,
-1.0878, 2.1031],
[ 0.4759, -0.4706, -0.5551, 0.7296, 0.0382, 0.9609, -1.6298, -0.1432,
-0.5032, -0.3488],
[ 0.0286, 0.9674, 0.0237, -0.9054, 0.1151, 1.0193, 1.1148, 0.1982,
-0.6999, 0.3033],
[-0.1647, 0.3057, -0.0309, 1.9282, 1.7134, 1.1035, -0.8432, -0.5180,
-0.1395, 0.0347],
[-1.5691, 1.1859, -0.9767, -1.3447, 0.6255, -0.4162, 0.5525, 1.4220,
-2.5235, 0.5372],
[-0.5764, -0.2917, -2.1132, 1.9557, -1.4700, 1.6644, 1.4034, 0.3968,
-0.8569, 0.2588],
[-0.7345, 0.9567, -0.6932, -0.1686, -0.2532, -0.7115, 0.4753, -0.9203,
1.1279, -1.2294],
[-0.7426, -0.1470, -0.3995, -0.4880, 0.3165, 1.3640, 0.4278, 1.0734,
-1.9965, -1.7040],
[ 0.3277, -1.4838, -0.6945, -1.0994, -0.4837, -1.9070, 0.1701, -0.2308,
1.0275, 1.3653],
[ 2.0126, 0.5750, -0.3649, -0.2077, -0.0483, 0.4669, -0.5713, 2.7418,
0.2346, 0.0564],
[ 2.1609, 0.4623, -1.4637, -1.6748, 0.7956, -0.2733, -0.3300, -0.1264,
-1.6609, -0.5499],
[ 0.9671, -0.7844, -1.7737, -1.7201, -0.6303, -0.2596, 0.5687, 1.5687,
0.3326, 0.2705]])
Pass input_text to fc as its input:
>>> result_output=fc(input_text)
Print the output:
>>> result_output
tensor([[ 1.8107e-01, -9.2258e-02, 2.4591e-01, -1.9315e-01, -3.3488e-01,
1.2616e-01, 4.5071e-01, -2.7091e-01, 1.3229e-01, 8.5799e-01,
-3.0633e-01, 1.4934e-01, 1.5582e+00, 2.7456e-01, -2.0479e-01],
[-4.4143e-01, -1.2819e+00, -9.4922e-01, -8.4364e-01, -4.7017e-01,
-3.4534e-03, 6.7187e-01, 2.8355e-01, 4.0641e-01, 2.6964e-01,
-6.9916e-01, -8.4993e-01, 1.3252e-01, -4.3188e-01, 4.7683e-01],
[-1.2963e+00, -7.0305e-01, 3.1611e-01, 2.5703e-01, -5.6557e-01,
3.3875e-01, -7.9337e-01, 9.0740e-01, 8.6513e-01, 1.9334e-01,
-2.6332e-01, 2.7622e-01, -3.4500e-01, 4.9011e-01, -8.4280e-01],
[-1.1210e-02, 2.5886e-01, 8.8588e-01, 7.9358e-01, 5.5724e-02,
-1.2594e-01, 2.1723e-01, -4.5786e-02, -2.4303e-01, 2.9119e-01,
4.6201e-01, -3.2916e-01, 5.0725e-01, 6.8928e-01, -3.3190e-01],
[ 1.2481e+00, 1.2070e+00, 3.3738e-01, -2.5593e-01, 1.3632e+00,
-3.5514e-01, 1.7132e+00, -9.8362e-01, -2.0212e-01, -6.1235e-02,
9.5519e-01, -9.1147e-01, 6.6596e-01, 4.5335e-01, 4.4601e-01],
[-5.1957e-01, -1.0777e-01, 7.3596e-01, 8.8377e-02, -3.4743e-01,
-1.5220e-01, 3.0113e-01, -3.7542e-01, -2.8111e-01, 5.4499e-03,
3.9416e-03, -1.1461e-01, 3.5127e-01, -1.2291e-01, 1.4345e-01],
[-1.8349e-01, -5.6514e-01, 1.0635e-01, -6.4179e-01, -9.6386e-02,
-1.2538e-01, -9.9904e-01, 1.1812e+00, 5.4885e-01, 1.1835e+00,
-2.2399e-01, 1.1104e+00, -4.6530e-01, 5.3867e-01, -1.5022e+00],
[ 3.7626e-01, 1.5282e-01, 1.2204e-01, -6.1283e-01, -3.0752e-01,
5.5161e-02, 4.2992e-01, 3.8260e-01, -6.8061e-01, 2.0682e-01,
6.3711e-01, -1.2116e-01, 8.7099e-01, 6.0735e-01, 8.2737e-01],
[-8.1442e-01, -1.9702e-02, -2.0397e-01, -3.0586e-01, -3.8486e-01,
6.4765e-01, -3.4066e-01, 6.8476e-01, 3.9929e-01, -5.8264e-02,
6.7580e-01, -4.1936e-02, 4.5742e-01, 1.1757e+00, 2.8947e-01],
[-2.6684e-02, 4.0578e-01, 7.5473e-01, 1.4927e-01, 4.0935e-01,
-1.5429e-01, 2.7960e-01, -3.9922e-01, 1.3923e-01, 4.1013e-01,
2.2092e-01, -1.9759e-02, 3.7310e-01, 6.4977e-01, -1.7508e-01],
[-5.1254e-01, -8.0600e-01, -6.7048e-01, -5.7318e-01, -7.2053e-01,
4.1362e-01, 7.7607e-02, 5.0234e-01, 5.1123e-01, 5.5079e-01,
-2.2140e-01, -3.5574e-02, 9.4809e-01, 1.5752e-01, -4.7720e-02],
[ 7.3687e-03, -3.5115e-01, 6.1108e-01, 4.9131e-02, 7.5905e-01,
-2.5723e-01, 1.0520e-01, 5.2249e-01, 1.2880e+00, 7.2025e-01,
-1.0097e+00, -1.5668e-01, -7.5633e-01, 6.8615e-01, -8.1866e-01],
[ 4.0765e-01, -1.9088e-02, -2.8948e-01, -1.4690e+00, -1.6058e-01,
5.9143e-01, 3.3860e-01, 7.5121e-01, 5.8912e-01, 4.1719e-01,
-5.3752e-01, 5.2774e-02, 9.3624e-01, 1.1130e+00, 3.4052e-01],
[ 1.0823e+00, -4.6622e-01, -9.0878e-01, -1.0677e+00, 4.4945e-01,
-4.2928e-01, 9.1671e-01, 7.0039e-02, 1.3152e-01, 1.4396e+00,
-3.2818e-02, -1.5180e-01, 8.9904e-01, 4.7343e-02, 2.4571e-01],
[ 4.2488e-01, -4.1203e-01, -1.5616e-02, -5.7839e-01, 2.7431e-02,
-3.7477e-01, 1.1616e+00, -3.6302e-01, -3.9705e-01, -5.5184e-02,
-5.1441e-01, -1.0853e+00, 5.5814e-02, -4.7424e-01, -1.6333e-01],
[ 6.4548e-01, -9.6785e-02, 4.4011e-01, -2.2773e-01, 7.7957e-02,
-5.9443e-02, 7.4260e-01, 4.5195e-02, 1.5288e-01, 9.1478e-01,
-6.7894e-01, -3.9288e-01, 1.1380e+00, 8.8407e-01, 6.3056e-01],
[-1.4826e-01, 4.7014e-01, 2.3744e-01, -1.8187e-01, -1.5348e-01,
-4.0711e-02, 5.3915e-01, -1.7642e-01, -1.1896e+00, -8.9747e-01,
1.4008e+00, -8.2545e-01, -4.0002e-03, 2.1347e-01, 3.0044e-01],
[-6.7112e-01, -7.4937e-01, 4.4586e-01, 2.8668e-02, -1.2300e+00,
-1.6951e-01, -1.0214e+00, 1.4511e-01, -8.5860e-01, 1.4472e+00,
-4.0698e-02, 8.9120e-01, 1.0358e+00, 4.7799e-01, -4.3145e-01],
[ 1.5972e-01, 4.2159e-01, 2.3895e-01, 3.2234e-01, -5.4578e-01,
5.0599e-01, 6.8107e-02, 9.7544e-02, -3.9385e-01, 4.5454e-01,
8.3391e-01, -1.6128e-03, 1.8425e+00, 1.0067e+00, -6.3919e-01],
[-4.8341e-02, -1.9460e-01, -2.8181e-01, -1.1646e-01, -8.9199e-01,
-2.8274e-02, 2.9038e-01, -2.7133e-01, -1.5819e+00, 2.9762e-01,
1.0687e+00, -6.4864e-01, 1.3276e+00, 4.1109e-01, 5.7142e-01]],
grad_fn=<AddmmBackward>)
>>> result_output.size()
torch.Size([20, 15])
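The output size [20, 15] follows from the layer's parameters. As a small sketch, assuming the same nn.Linear(10, 15) layer, the parameter shapes can be inspected directly. The variable names weight_shape, bias_shape and num_params are illustrative.

```python
import torch
from torch import nn

fc = nn.Linear(10, 15)

# One weight row per output neuron, one column per input feature
weight_shape = tuple(fc.weight.shape)
# One bias value per output neuron
bias_shape = tuple(fc.bias.shape)
# Total learnable parameters: 15 * 10 weights + 15 biases
num_params = sum(p.numel() for p in fc.parameters())

print(weight_shape, bias_shape, num_params)  # (15, 10) (15,) 165
```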
Next, pass an input whose feature dimension is 15 rather than the 10 that fc expects:
>>> input_text3 = torch.randn(2, 15)
>>> result_output3=fc(input_text3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "/usr/local/lib/python3.6/site-packages/torch/nn/functional.py", line 1370, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [2 x 15], m2: [10 x 15] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136
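The error occurs because the last dimension of the input (15) does not match the layer's in_features (10), so the matrix product of a [2 x 15] input with the transposed weight of shape [10 x 15] is undefined. The sketch below reproduces the failure and shows one way to check the shape up front. check_input is a hypothetical helper written for this note, not part of PyTorch, and the exact wording of the RuntimeError varies between PyTorch versions.

```python
import torch
from torch import nn

fc = nn.Linear(10, 15)


def check_input(layer, x):
    """Hypothetical helper: fail early with a clear message on a feature mismatch."""
    if x.shape[-1] != layer.in_features:
        raise ValueError(
            f"expected last dimension {layer.in_features}, got {x.shape[-1]}"
        )
    return x


# Wrong feature size: 15 instead of 10
bad = torch.randn(2, 15)
try:
    fc(bad)
    raised = False
except RuntimeError as exc:
    raised = True
    print("RuntimeError:", exc)

# Correct feature size: the batch size (2 here) can be anything
good = torch.randn(2, 10)
out = fc(check_input(fc, good))
print(out.shape)
```

Only the feature dimension has to match in_features. The batch dimension is free, which is why a batch of 2 works as well as the batch of 20 used earlier.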