本文主要讲述最简单的线性回归函数,个人理解定义一个nn.Linear就相当于定义下面的函数:
讲解上述公式在pytorch的实现,主要包括nn.Linear的源码解读和实例展示。
先看一下Linear类的实现:源码地址
Linear继承于nn.Module,内部函数主要有__init__,reset_parameters, forward和 extra_repr函数
,下面是部分源码:
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
@weak_script_method
def forward(self, input):
return F.linear(input, self.weight, self.bias)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
1.从__init__
函数中可以看出Linear中包含四个属性
2.reset_parameters(self)
参数初始化函数
在__init__中调用此函数,权重采用Xvaier initialization 初始化方式初始参数。
3.forward(self, input)
在Module的__call__函数调用此函数,使得类对象具有函数调用的功能,同过此功能实现pytorch的网络结构堆叠。
import torch
a = torch.randn(60, 30) # 输入的维度是(60,30)
b = torch.nn.Linear(30, 15) # 输入的维度是(30,15)
output = b(a)
print('b.weight.shape:\n ', b.weight.shape)
print('b.bias.shape:\n', b.bias.shape)
print('output.shape:\n', output.shape)
# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(a, b.weight.t()) + b.bias
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))
结果如下:
b.weight.shape:
torch.Size([15, 30])
b.bias.shape:
torch.Size([15])
output.shape:
torch.Size([60, 15])
ans.shape:
torch.Size([60, 15])
True
因为经过了上述公式的转置,所以b.weight.shape = (15,30),这样才方便进行计算。
如有不对,多多交流。参考