(三)深度学习计算 -- 2

2.模型参数的访问、初始化和共享

import tensorflow as tf

2.1 access parameters

先定义一个含单隐藏层的多层感知机,用默认方式进行参数初始化,并做一次前向计算:

# 构造
net = tf.keras.models.Sequential()
net.add(tf.keras.layers.Flatten())
net.add(tf.keras.layers.Dense(units=256,activation=tf.nn.relu))
net.add(tf.keras.layers.Dense(units=10))
X = tf.random.normal(shape=(2,20))
Y = net(X)
Y

输出:



对于使用Sequential类构造的神经网络,可以通过weights属性访问网络任一层的权重。

访问多层感知机net中隐藏层的所有参数:

net.weights[0], type(net.weights[0])

输出:

(,
 tensorflow.python.ops.resource_variable_ops.ResourceVariable)

其中,索引0表示隐藏层的参数。


2.2 initialize params

将权重参数初始化成均值为0、标准差为0.01的正态分布随机数,并将偏差参数清零:

class Linear(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(
            units=10,
            activation=None,
            kernel_initializer=tf.random_normal_initializer(mean=0,stddev=0.01),
            bias_initializer=tf.zeros_initializer()
        )
        self.d2 = tf.keras.layers.Dense(
            units=1,
            activation=None,
            kernel_initializer=tf.ones_initializer(),
            bias_initializer=tf.ones_initializer()
        )

    def call(self, input):
        output = self.d1(input)
        output = self.d2(output)
        return output
X = tf.random.normal(shape=(2,20))
net = Linear()
Y = net(X)
Y

输出:


net.get_weights()

输出:

[array([[ 1.08196139e-02,  5.79477428e-03, -1.06005976e-02,
          8.26828275e-03,  5.04667917e-03,  1.72793642e-02,
         -3.92814586e-03, -2.72732018e-03,  5.62155573e-03,
         -3.02137225e-04],
        [ 3.87676549e-03,  1.17851729e-02,  1.08660981e-02,
         -6.99408259e-03,  1.92590989e-03, -1.83490932e-03,
          6.89283572e-03, -7.06359418e-03,  3.14420962e-04,
          6.61447644e-04],
        [ 8.02275259e-03,  1.57884490e-02,  2.26609013e-03,
         -2.23289663e-03,  9.47015267e-03, -7.17470248e-04,
          1.16258478e-02,  8.79523432e-05,  3.11152334e-03,
         -1.19256079e-02],
        [ 6.94830623e-03,  1.22911343e-03,  2.33568228e-03,
          1.12228524e-02,  2.49789469e-02, -1.10885501e-02,
         -1.73985644e-03,  5.82345063e-03,  1.37769533e-02,
          2.61171046e-03],
        [-8.19462445e-03, -4.04411508e-03,  2.38322895e-02,
         -4.24177560e-04,  1.55501980e-02,  2.83632631e-04,
          1.01489350e-02,  1.03493668e-02,  3.76219791e-03,
          7.25238118e-03],
        [ 7.48372963e-03,  1.34342210e-02, -1.16917929e-02,
         -7.98194390e-03,  3.77418124e-03,  5.94647368e-04,
         -5.43513242e-03, -2.17209738e-02,  2.01016515e-02,
          8.55223276e-04],
        [ 6.30497979e-03, -1.42902685e-02, -5.65174408e-03,
          1.05804554e-03, -1.00298449e-02, -1.40983220e-02,
          9.58261453e-03,  1.90745108e-02, -1.15314289e-03,
         -1.60493851e-02],
        [ 1.65592954e-02, -6.49799331e-05,  7.29267858e-03,
         -1.64555088e-02,  7.86533114e-04,  9.04201809e-03,
          9.76761524e-03, -8.34710896e-03, -2.54614046e-03,
         -1.39755512e-05],
        [ 4.41787904e-03, -1.37950480e-02,  3.53181711e-03,
          1.53799134e-03, -1.65310241e-02,  8.04246869e-03,
         -1.56394043e-03,  4.09922330e-03, -6.75616204e-04,
          3.64435394e-03],
        [ 2.04873364e-03, -5.06615173e-03, -1.34313554e-02,
          1.44438725e-02, -2.43533379e-03, -2.76915147e-03,
          8.29426106e-03,  2.37593781e-02, -1.67158642e-03,
         -9.71452612e-03],
        [-6.51521317e-04, -2.00210139e-02,  4.81219077e-03,
          1.39641659e-02, -9.27464478e-03,  3.03924573e-03,
          8.56531877e-03,  7.31856516e-03, -1.51513061e-02,
          1.61066949e-02],
        [ 2.83615035e-03,  8.50357523e-04,  1.51638659e-02,
          9.69374366e-03, -8.76047241e-04,  8.99826828e-03,
         -5.30649349e-03, -4.62086685e-03,  1.37379104e-02,
         -1.76256541e-02],
        [-1.42976148e-02,  9.24051553e-03,  3.62900808e-03,
          3.00289178e-03,  5.01110684e-03,  1.03468634e-03,
          1.01388549e-03,  9.39975400e-03, -6.78061135e-03,
         -1.49735911e-02],
        [ 8.51881705e-05, -4.43281513e-03,  3.90838506e-03,
         -2.92743999e-03,  2.06509116e-03, -6.18821578e-05,
          8.80562980e-03,  2.54227780e-03,  2.56423809e-05,
         -4.66403132e-03],
        [ 8.15627631e-03,  7.13415863e-03, -2.01360248e-02,
          4.30102786e-03, -4.99079004e-04,  5.28379949e-03,
          1.58037301e-02, -3.75230238e-03, -8.95734876e-04,
          1.39328260e-02],
        [-2.48550382e-02, -6.31160010e-03, -6.22724183e-03,
         -9.44277551e-03,  1.10578630e-02, -6.14645949e-04,
         -2.52730539e-03,  2.29452504e-03,  9.93179274e-04,
         -2.90897046e-03],
        [ 6.50542649e-03, -1.21802492e-02, -1.15979975e-02,
          3.04045796e-04, -5.20680938e-03, -5.98686142e-03,
          1.02923429e-02, -2.00138614e-02,  2.01741327e-02,
          1.01022851e-02],
        [-4.86079371e-03,  3.91012290e-03,  1.68614404e-03,
         -7.48395314e-03, -5.46253286e-04, -8.52062460e-03,
          7.88574014e-03,  1.17413038e-02, -4.50282078e-03,
          4.58895974e-03],
        [ 7.92449064e-05,  3.17677343e-03, -1.36370957e-02,
         -1.07877040e-02,  1.27669917e-02,  9.87697672e-03,
         -4.51020384e-03, -9.63695813e-03,  5.63280517e-03,
         -2.42555770e-03],
        [ 1.30282128e-02, -1.77450315e-03, -9.25759878e-03,
         -9.80420282e-06, -3.06460354e-03, -1.73859636e-03,
          7.01046176e-03, -1.29886828e-02, -6.16477896e-03,
          4.15576249e-03]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], dtype=float32),
 array([1.], dtype=float32)]

2.3 define initializer

使用tf.keras.initializers类中的方法实现自定义初始化:

def my_init():
    """Initializer that generates tensors initialized to 1"""
    return tf.keras.initializers.Ones()

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=64,kernel_initializer=my_init()))
X = tf.random.normal(shape=(2,20))
Y = model(X)
Y

输出:


model.weights[0]

输出:




参考

《动手学深度学习》(TF2.0版)

你可能感兴趣的:(动手学深度学习)