Pytorch学习笔记——fan_in和fan_out

文章目录

  • 前言
  • fan_in和fan_out的含义是什么呢?
  • Pytorch中如何计算fan_int和fan_out?
    • 方法一
    • 方法二
  • 小结

前言

在前一篇文章中介绍了kaiming均匀初始化方法,其中有一个mode参数,可以是fan_infan_out,本文研究一下这两个参数的含义。

官方关于这个参数作用的解释:
选择"fan_in"可以保留前向计算中权重方差的大小。
选择"fan_out"将保留后向传播的方差大小。

#kaiming均匀初始化
torch.nn.init.kaiming_uniform_(
    tensor, 
    a=0, 
    mode='fan_in', 
    nonlinearity='leaky_relu'
)

fan_in和fan_out的含义是什么呢?

为了理解fan_infan_out,先要理解什么是fan

词典中的解释是这样的:

disperse or radiate from a central point to cover a wide area:
从一个中心点分散或辐射到一个大范围。

看一个例句:

“the arriving passengers began to fan out through the town in search of lodgings”
到达的乘客开始在城里散开,寻找住处。

既然fan代表从一个中心点分散到一个大范围,那么在神经网络中,这个中心点对应什么呢?

一种简单的理解方式是,这个中心点对应一个层

比如下面这个简单的Linear模型。

Pytorch学习笔记——fan_in和fan_out_第1张图片

上图就代表了一个层,对应一个nn.Linear,包含一个权重矩阵。

关于层。
有人按照绿色节点来算层,有人用中间的权重来算层,其实只是不同的视角,本质都是一样的。
我们观察Linear的实现,里面只有一个权重,并没有所谓节点。
最终训练得到的模型,也只是保留各层的权重。
但在画网络图的时候,节点又是重要的组成部分。

把层作为fan对应的中心点,那么:
这个层的输入就是fan_in
这个层的输出就是fan_out

如果从节点的角度来理解fan_in和fan_out:

fan_in:本层每个节点的输入个数。

答案是4个,因为上一层有4个节点。

fan_out:就是本层的输出节点个数(也就是本层节点数)

答案是6,因为本层有6个节点。

Pytorch中如何计算fan_int和fan_out?

仍然用一个线性层为例来说明。

方法一

使用_calculate_fan_in_and_fan_out同时计算fan_infan_out

>>> linear = nn.Linear(4, 6)
>
>>> linear
Linear(
    in_features=4, 
    out_features=6,
    bias=True)

>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])
# 注意,这里weight的形状
# 对于矩阵左乘,Wx
# 将4维输入x映射为6维输出
# 所以w是 6 * 4
# 如果是矩阵右乘,则相反

>>> print(
... nn.init._calculate_fan_in_and_fan_out(
... linear.weight)
... )
(4, 6)
# fan_in: 4
# fan_out: 6

方法二

单独计算fan_infan_out

>>> w = linear.weight
>>> w.shape
torch.Size([6, 4])

>>> torch.nn.init._calculate_correct_fan(
... w,
... mode='fan_in')
4

从这个简单的例子可以看到,对于Linear模型,fan_in就是输入size,fan_out就是输出size。

小结

本文侧重于理解fan_in和fan_out的含义,通过对fan的原始含义进行分析,发现用层对应fan的中心点是一种理解方式。


文章修改记录
2022.04.06:修改网络层的说明。

你可能感兴趣的:(深度学习,pytorch,神经网络)