PyTorch - nn.Linear

Ctrl-click nn.Linear in your IDE to jump to its source code:

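import math
import torch
from torch.nn import Module, Parameter, init
import torch.nn.functional as F

class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # weight has shape (out_features, in_features)
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            # keep the attribute name 'bias' registered, but with no parameter
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming uniform with a=sqrt(5) gives weight ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            # bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))
            init.uniform_(self.bias, -bound, bound)

    # PyTorch 1.0 decorated forward with @weak_script_method,
    # a TorchScript helper that newer releases no longer have
    def forward(self, input):
        # y = x W^T + b
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        # text shown inside Linear(...) when the module is printed
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )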

nn.Linear inherits from nn.Module. Its main methods are __init__, reset_parameters, forward, and extra_repr.
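A quick check of reset_parameters: kaiming_uniform_ with a = sqrt(5) uses the gain sqrt(2 / (1 + a^2)) = sqrt(1/3), so its bound sqrt(3) * gain / sqrt(fan_in) simplifies to 1/sqrt(fan_in), the same bound the code uses for the bias. The sketch below (layer sizes and seed chosen arbitrarily) checks that freshly initialized weights and biases fall inside that range.

```python
import math
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
in_features, out_features = 20, 30
m = torch.nn.Linear(in_features, out_features)

# fan_in = in_features for a weight of shape (out_features, in_features),
# so weight and bias should both lie in [-1/sqrt(in_features), 1/sqrt(in_features)]
bound = 1 / math.sqrt(in_features)
tol = 1e-6  # small slack for float32 rounding
weight_ok = bool((m.weight.abs() <= bound + tol).all())
bias_ok = bool((m.bias.abs() <= bound + tol).all())
print(bound, weight_ok, bias_ok)
```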

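__init__(self, in_features, out_features, bias=True)
in_features: number of neurons in the previous layer, i.e. the size of each input sample
out_features: number of neurons in this layer, i.e. the size of each output sample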
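When bias=False, __init__ calls register_parameter('bias', None): the attribute still exists, so forward can pass self.bias to F.linear unchanged, but it holds None and only the weight is a learnable parameter. A small sketch (batch size of 4 chosen arbitrarily):

```python
import torch

m = torch.nn.Linear(20, 30, bias=False)
x = torch.randn(4, 20)  # 4 samples, 20 features each

# the bias attribute exists but is None; only the weight is registered
names = [name for name, _ in m.named_parameters()]
print(m.bias, names)

# without a bias the layer is a pure matrix product: y = x W^T
y = m(x)
print(torch.allclose(y, x @ m.weight.T))
```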

Docstring:
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

Args:
    in_features: size of each input sample
    out_features: size of each output sample
    bias: If set to False, the layer will not learn an additive bias.
        Default: True
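The formula y = xA^T + b acts on the last dimension of the input, so nn.Linear also accepts inputs with extra leading dimensions, of shape (*, in_features). The sketch below (shapes chosen only for illustration) compares a 3-D input against the formula written out with torch.matmul, which broadcasts over the leading dimensions.

```python
import torch

torch.manual_seed(0)
m = torch.nn.Linear(20, 30)
x = torch.randn(8, 5, 20)  # e.g. 8 sequences of 5 vectors, 20 features each

y = m(x)
# y = x A^T + b: matmul broadcasts over the leading (8, 5) dims,
# and b (shape (30,)) is broadcast onto every output row
y_manual = torch.matmul(x, m.weight.t()) + m.bias
print(y.shape, torch.allclose(y, y_manual, atol=1e-6))
```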

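Attributes (the learnable parameters of nn.Linear):
    weight: shape (out_features, in_features), initialized from U(-1/sqrt(in_features), 1/sqrt(in_features))
    bias: shape (out_features), initialized from the same distribution when bias=True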
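Both attributes are nn.Parameter objects, so they show up in parameters() and are updated by an optimizer. For nn.Linear(in_features, out_features) the number of learnable values is in_features * out_features + out_features; for nn.Linear(20, 30) that is 20 * 30 + 30 = 630.

```python
import torch

m = torch.nn.Linear(20, 30)

# weight: (out_features, in_features), bias: (out_features,)
shapes = {name: tuple(p.shape) for name, p in m.named_parameters()}
n_params = sum(p.numel() for p in m.parameters())
print(shapes, n_params)  # {'weight': (30, 20), 'bias': (30,)} 630
```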

Examples:
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
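The extra_repr method only affects how the module prints: nn.Module's __repr__ puts the string it returns inside the parentheses after the class name, which is where the bias flag appears as True or False (self.bias is not None).

```python
import torch

m = torch.nn.Linear(20, 30)
m_nobias = torch.nn.Linear(20, 30, bias=False)

# __repr__ wraps the string returned by extra_repr()
print(repr(m))                # Linear(in_features=20, out_features=30, bias=True)
print(m_nobias.extra_repr())  # in_features=20, out_features=30, bias=False
```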

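import torch

x = torch.randn(128, 20)  # input shape is (128, 20)
m = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = m(x)
# compute y = x W^T + b by hand for comparison
ans = torch.mm(x, m.weight.t()) + m.bias
print('m.weight.shape:\n', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
print('ans.shape:\n', ans.shape)
# compare with a tolerance; bit-exact equality is not guaranteed across kernels
print(torch.allclose(ans, output))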

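For nn.Linear(20, 30):
the input x has shape (128, 20)
the weight w (A in the formula) has shape (30, 20)
the bias b has shape (30,)
the output has shape (128, 30): x·w^T is (128, 20) × (20, 30) = (128, 30), and b is broadcast across all 128 rows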
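Element by element, output[i, j] = sum over k of x[i, k] * w[j, k] + b[j]: each of the 30 output neurons takes a weighted sum of the 20 input features using its own row of w. The loop version below only makes the indexing explicit (it is far slower than m(x)), and it uses a batch of 4 samples instead of 128 to keep it quick.

```python
import torch

torch.manual_seed(0)
m = torch.nn.Linear(20, 30)
x = torch.randn(4, 20)  # smaller batch than 128, same feature sizes
w, b = m.weight.detach(), m.bias.detach()

# output[i, j] = sum_k x[i, k] * w[j, k] + b[j]
out = torch.empty(4, 30)
for i in range(4):          # each sample
    for j in range(30):     # each output neuron (row j of w)
        out[i, j] = (x[i] * w[j]).sum() + b[j]

print(torch.allclose(out, m(x), atol=1e-5))
```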
