nn.Linear()

官网 nn.Linear()详解

Linear

作用:对输入数据进行线性变换

例子:

import torch
m = torch.nn.Linear(20, 30)
input = torch.randn(128, 20)#输入数据的维度(128,20)
output = m(input)
print(m.weight.shape)
print(m.bias.shape)
print(output.size())
 >>
torch.Size([30, 20])
torch.Size([30])
torch.Size([128, 30])
>>

理解:

线性变换的权重值 weight 和 偏置值 bias 会伴随训练过程不管更新参数,也就是注释中的 learnable ,他们的初始时刻都随机初始化 在区间 :
(,) ,
上面的例子可以看到,输入数据会跟一个权重矩阵 A 相乘,A.shape=[30, 20],偏重为一个一维tensor,长度为[30],权重矩阵相乘得到的128个30维的向量,最后会给每一个向量加上这个偏置误差tensor,所以就对应线性变换公式:

于是nn.Linear()也等价与下面的:

output = torch.mm(input , m.weight.t()) + m.bias  
print(output.size())
>>torch.Size([128, 30])

这个函数是用来设置神经网络中的全连接层的,输入输出都是二维 tensor
in_features:指的是输入的二维tensor的大小,即输入的[batch_size, size]中的size。
out_features:指的是输出的二维tensor的大小,即输出的二维张量的形状为[batch_size,output_size],也代表了该全连接层的神经元个数。

你可能感兴趣的:(nn.Linear())