这个函数就是实现这个公式:$\text{std} = \dfrac{\text{gain}}{\sqrt{\text{fan\_mode}}}$
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fill ``tensor`` in-place using Kaiming (He) normal initialization.

    Values are drawn from a normal distribution with mean 0 and standard
    deviation ``gain / sqrt(fan)``, where ``fan`` is chosen by ``mode`` and
    ``gain`` is ``calculate_gain(nonlinearity, a)``.

    Args:
        tensor: tensor to initialize; it is modified in-place.
        a: parameter forwarded to ``calculate_gain`` (the negative slope
            when ``nonlinearity`` is ``'leaky_relu'``).
        mode: ``'fan_in'`` or ``'fan_out'`` (case-insensitive).
        nonlinearity: name of the non-linear function used to pick the gain.

    Returns:
        The same ``tensor`` object that was passed in.

    Raises:
        ValueError: propagated from the helpers for an unsupported ``mode``
            or ``nonlinearity``, or a tensor with fewer than 2 dimensions.
    """
    # A tensor with a zero-sized dimension has nothing to fill, and can make
    # the fan equal to 0 (division by zero below), so treat it as a no-op.
    if 0 in tensor.shape:
        import warnings

        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        # Overwrite the tensor in-place with samples from N(0, std^2).
        return tensor.normal_(0, std)
_calculate_correct_fan(tensor, mode)
根据mode返回fan_in或fan_out(即输入或输出feature map的数量乘以感受野大小),源码为:
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
# 这里是fmap的大小
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
# 根据mode选择返回数据
return fan_in if mode == 'fan_in' else fan_out
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
# 这里相当于输出了前两维的size
num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
# 这里相当于计算了后两维的元素总和
receptive_field_size = 1
if tensor.dim() > 2:
# numel()的作用就是计算元素的个数
receptive_field_size = tensor[0][0].numel()
# 然后算出in/out的fmap的大小
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
上面源码可以用下列例子解释:
感谢评论区大佬指出错误,num_input_fmaps是用的size(1),num_output_fmaps用的size(0)
calculate_gain(nonlinearity, a)
如果选的是relu,那么return math.sqrt(2.0),即根号2,下面是源码,其中注释给出了详细的gain值
def calculate_gain(nonlinearity, param=None):
    r"""Look up the recommended gain for a nonlinearity.

    The gains returned are:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (0.01 if ``None``)

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is
            neither ``None`` nor a non-boolean int/float for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    # Linear-like layers and sigmoid all share a gain of 1.
    unit_gain_fns = (
        'linear',
        'conv1d', 'conv2d', 'conv3d',
        'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
        'sigmoid',
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            slope = 0.01
        elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
            # bool is a subclass of int, so it has to be excluded explicitly.
            slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + slope ** 2))
    if nonlinearity == 'selu':
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
tensor.normal_(0, std)