神经网络的权重初始化
flyfish 笔记
如果激活函数是tanh
w[l] = np.random.randn(n[l],n[l-1])*np.sqrt(1/n[l-1])
如果激活函数是ReLU
w[l] = np.random.randn(n[l],n[l-1])*np.sqrt(2/n[l-1])
Yoshua Bengio初始化w的方法
w[l] = np.random.randn(n[l],n[l-1])*np.sqrt(2/(n[l-1]+n[l]))
import numpy as np
def initialize_parameters_he(layers_dims, seed=0):
    """Build He-initialized weights and zero biases for a dense network.

    Each weight matrix is drawn from a standard normal distribution and
    scaled by ``sqrt(2 / fan_in)``, the scaling used for ReLU activations.

    Args:
        layers_dims: Sequence of layer sizes, input layer first. Entry ``l``
            is the number of units in layer ``l``.
        seed: Value passed to ``np.random.seed`` before sampling so results
            are reproducible. Defaults to 0 (the historical behaviour).
            Pass ``None`` to leave the global NumPy RNG state untouched.

    Returns:
        Dict with keys ``"W1".."WL"`` and ``"b1".."bL"`` where ``L`` is
        ``len(layers_dims) - 1``. ``Wl`` has shape
        ``(layers_dims[l], layers_dims[l-1])`` and ``bl`` has shape
        ``(layers_dims[l], 1)``. Empty when fewer than two sizes are given.
    """
    if seed is not None:
        np.random.seed(seed)

    parameters = {}
    num_layers = len(layers_dims) - 1  # number of layers, not counting the input
    for layer in range(1, num_layers + 1):  # layers 1..L inclusive
        fan_out = layers_dims[layer]
        fan_in = layers_dims[layer - 1]

        # Progress trace. The label text is kept verbatim (including the
        # "1" in "layers_dims 1:") so the printed output is unchanged.
        print("l:", layer)
        print("layers_dims 1:", fan_out)
        print("layers_dims l-1:", fan_in)

        # He initialization: variance 2 / fan_in keeps ReLU activations
        # from shrinking or exploding as depth grows.
        parameters['W' + str(layer)] = (
            np.random.randn(fan_out, fan_in) * np.sqrt(2. / fan_in)
        )
        parameters['b' + str(layer)] = np.zeros((fan_out, 1))
    return parameters
# Build a 2-4-3-8-1 network and dump every weight matrix and bias vector,
# layer by layer, in the form "<name> = <array>".
parameters = initialize_parameters_he([2, 4, 3, 8, 1])
for name in ("W1", "b1", "W2", "b2", "W3", "b3", "W4", "b4"):
    print(name + " = " + str(parameters[name]))
输出
l: 1
layers_dims 1: 4
layers_dims l-1: 2
l: 2
layers_dims 1: 3
layers_dims l-1: 4
l: 3
layers_dims 1: 8
layers_dims l-1: 3
l: 4
layers_dims 1: 1
layers_dims l-1: 8
W1 = [[ 1.76405235 0.40015721]
[ 0.97873798 2.2408932 ]
[ 1.86755799 -0.97727788]
[ 0.95008842 -0.15135721]]
b1 = [[0.]
[0.]
[0.]
[0.]]
W2 = [[-0.07298675 0.29033699 0.10185419 1.02832666]
[ 0.53813494 0.08603723 0.3138587 0.23594338]
[ 1.05647344 -0.1450688 0.22137229 -0.60393689]]
b2 = [[0.]
[0.]
[0.]]
W3 = [[-2.08450746 0.53367735 0.7058092 ]
[-0.6059752 1.85324689 -1.1874846 ]
[ 0.03736167 -0.15283497 1.25150899]
[ 1.19972641 0.12651404 0.3087684 ]
[-0.72487403 -1.61731354 -0.28406908]
[ 0.1276584 1.00452813 0.98173904]
[-0.31625102 -0.24682916 -0.85613991]
[-1.15943979 -1.39316378 1.59280144]]
b3 = [[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]]
W4 = [[-0.25482609 -0.21903715 -0.62639768 0.38874518 -0.80694892 -0.10637014
-0.44773328 0.19345125]]
b4 = [[0.]]
关于随机数代码的解释
import numpy as np
import matplotlib.pyplot as plt
# Demo of NumPy's global random state. Statement order matters here: each
# call advances the shared generator, so the sample outputs recorded below
# belong to the specific call they follow.

# Without setting a seed, each run generates different random numbers.
print(np.random.rand(3,2))
'''
[[0.63611338 0.81176128]
[0.64378796 0.81751788]
[0.11384994 0.79912718]]
'''
print(np.random.rand(3,2))
'''
[[0.31166034 0.59363663]
[0.45223468 0.01044228]
[0.46475564 0.74117943]]
'''
# The same seed yields the same output: reseeding with 0 before each call
# makes both prints show an identical array.
np.random.seed(0)
print(np.random.rand(5))
#[0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
np.random.seed(0)
print(np.random.rand(5))
#[0.5488135 0.71518937 0.60276338 0.54488318 0.4236548 ]
# Normal distribution histogram: 500 samples from randn, drawn with 500 bins.
x=np.random.randn(500)
plt.hist(x,500)
plt.show()
# A 4x2 array of samples from randn.
print(np.random.randn(4,2))