nn.Linear()函数详解

nn.Linear()函数详解

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[原文地址](Linear — PyTorch 1.12 documentation)

其中的参数:

  • in_features – 每个输入样本的大小。
  • out_features – 每个输出样本的大小。
  • bias – 如果设置为False,该层将不会学习附加偏差。默认为:True。

shape:

  • Input:(#,IN),其中#表示表示任意大小的维度。IN表示in_features
  • Outout:(#,OUT),其中除了最后一个OUT外,其余的维度都和输入的shape相同,OUT表示out_features

实例:

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

你可能感兴趣的:(pytorch学习,深度学习,pytorch,python)