对pytorch中nn.Linear()的理解

本文主要讲述最简单的线性回归函数,个人理解定义一个nn.Linear就相当于定义下面的函数:
在这里插入图片描述
讲解上述公式在pytorch的实现,主要包括nn.Linear的源码解读实例展示

1. nn.Linear 源码解读

先看一下Linear类的实现:源码地址
Linear继承于nn.Module,内部函数主要有__init__,reset_parameters, forward和 extra_repr函数,下面是部分源码:

 def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

1.从__init__函数中可以看出Linear中包含四个属性

  • in_features: 上层神经元个数
  • out_features: 本层神经元个数
  • weight:权重, 形状[out_features , in_features]
  • bias: 偏置, 形状[out_features]

2.reset_parameters(self)
参数初始化函数
在__init__中调用此函数,权重采用Xvaier initialization 初始化方式初始参数。

3.forward(self, input)
在Module的__call__函数调用此函数,使得类对象具有函数调用的功能,同过此功能实现pytorch的网络结构堆叠。

2. 实例展示

import torch
a = torch.randn(60, 30)  # 输入的维度是(60,30)
b = torch.nn.Linear(30, 15)  # 输入的维度是(30,15)
output = b(a)
print('b.weight.shape:\n ', b.weight.shape)
print('b.bias.shape:\n', b.bias.shape)
print('output.shape:\n', output.shape)

# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(a, b.weight.t()) + b.bias
print('ans.shape:\n', ans.shape)

print(torch.equal(ans, output))

结果如下:

b.weight.shape:
  torch.Size([15, 30])
b.bias.shape:
 torch.Size([15])
output.shape:
 torch.Size([60, 15])
ans.shape:
 torch.Size([60, 15])
True

因为经过了上述公式的转置,所以b.weight.shape = (15,30),这样才方便进行计算。

如有不对,多多交流。参考

你可能感兴趣的:(计算机视觉)