TensorFlow 2.0 Beginner Tutorial 13: Custom Layers, Loss Functions, and Evaluation Metrics

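What if none of the existing layers meets my needs and I have to define my own? Just as we can subclass tf.keras.Model to write our own model classes, we can subclass tf.keras.layers.Layer to write our own layers.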

Custom layers

A custom layer subclasses tf.keras.layers.Layer and overrides three methods, __init__, build, and call, as shown below:

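import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Initialization code (store hyperparameters, create sublayers, etc.)

    def build(self, input_shape):
        # input_shape is a TensorShape object giving the shape of the input.
        # build() runs the first time the layer is used, so variables created here
        # adapt to the input's shape without the user having to specify it.
        # If a variable's shape is fully known in advance, it can also be created in __init__.
        # (The shapes below are only illustrative.)
        self.variable_0 = self.add_weight(name='variable_0', shape=(input_shape[-1],), initializer='zeros')
        self.variable_1 = self.add_weight(name='variable_1', shape=(), initializer='zeros')

    def call(self, inputs):
        # Forward-pass code: process the inputs and return the output
        output = inputs * self.variable_0 + self.variable_1
        return output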
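The lazy creation of weights in build can be observed directly. Below is a minimal sketch assuming TensorFlow 2.x; `ScaleLayer` is a hypothetical layer made up for illustration, whose single weight takes its length from the last dimension of the first input it receives.

```python
import tensorflow as tf


class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical example layer: multiplies each feature by a learnable scale."""

    def build(self, input_shape):
        # The weight's length comes from the input's last dimension,
        # so the user never has to pass it in explicitly.
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, inputs):
        # The (features,) scale broadcasts over the batch dimension.
        return inputs * self.scale


layer = ScaleLayer()
print(layer.built)            # False: build() has not run yet
out = layer(tf.ones((2, 3)))  # the first call triggers build() with shape (2, 3)
print(layer.built)            # True
print(layer.scale.shape)      # a length-3 weight, matching the 3 input features
```

Once built, the weight's shape is fixed, so calling this layer later on inputs with a different last dimension would fail to broadcast.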

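For example, to implement a fully connected layer ourselves (the equivalent of tf.keras.layers.Dense), we can write the following. The code creates two variables in build and uses them for computation in call: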

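add_weight: creates a weight variable tracked by the layer; it is trainable by default (pass trainable=False for a non-trainable one).

add_variable: a deprecated alias of add_weight in TF 2.x; new code should call add_weight.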

class LinearLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):     # input_shape is the shape of `inputs` the first time call() runs
        self.w = self.add_weight(name='w',
            shape=[input_shape[-1], self.units], initializer=tf.zeros_initializer())
        self.b = self.add_weight(name='b',
            shape=[self.units], initializer=tf.zeros_initializer())

    def call(self, inputs):
        y_pred = tf.matmul(inputs, self.w) + self.b  # y = x * w + b
        return y_pred
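To check that LinearLayer really behaves like tf.keras.layers.Dense (which applies no activation by default), one can copy a Dense layer's weights into it and compare outputs. A sketch assuming TensorFlow 2.x; the layer is repeated here using add_weight and the "zeros" initializer string so the snippet runs on its own.

```python
import numpy as np
import tensorflow as tf


class LinearLayer(tf.keras.layers.Layer):
    # Same layer as in the text, written with add_weight.
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(name="w", shape=[input_shape[-1], self.units],
                                 initializer="zeros")
        self.b = self.add_weight(name="b", shape=[self.units], initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


x = tf.random.normal((4, 2))
dense = tf.keras.layers.Dense(3)  # built-in layer, no activation by default
custom = LinearLayer(units=3)
dense(x)    # the first calls build both layers,
custom(x)   # creating a (2, 3) kernel and a (3,) bias in each

# get_weights() lists [kernel, bias] for Dense and [w, b] for LinearLayer,
# in creation order, so the two lists line up one to one.
custom.set_weights(dense.get_weights())
print(np.allclose(dense(x), custom(x)))  # True
```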

When defining a model, we can then use our custom LinearLayer just like any other Keras layer:

class LinearModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = LinearLayer(units=1)

    def call(self, inputs):
        output = self.layer1(inputs)
        return output
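A quick way to confirm that the model picks up the custom layer's variables: the sketch below (assuming TensorFlow 2.x, with LinearLayer repeated so it runs on its own) calls the model once and counts its trainable variables.

```python
import tensorflow as tf


class LinearLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(name="w", shape=[input_shape[-1], self.units],
                                 initializer="zeros")
        self.b = self.add_weight(name="b", shape=[self.units], initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


class LinearModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = LinearLayer(units=1)  # tracked automatically as an attribute

    def call(self, inputs):
        return self.layer1(inputs)


model = LinearModel()
out = model(tf.constant([[1.0], [2.0], [3.0]]))  # 3 samples, 1 feature each
print(out.shape)                                 # 3 rows, 1 output unit
print(len(model.trainable_variables))            # 2: the layer's w and b
```

Because both weights start at zero, every output is 0 before training.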

Custom loss functions and evaluation metrics

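A custom loss function subclasses tf.keras.losses.Loss and only needs to override call: it takes the ground truth y_true and the model prediction y_pred, and returns the loss between them as computed by the custom formula. The example below implements the mean squared error loss: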

class MeanSquaredError(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.square(y_pred - y_true))
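For inputs of the same shape, this custom loss should agree with the built-in tf.keras.losses.MeanSquaredError: the custom version averages over all elements at once, while the built-in one averages over the last axis and then over the batch, and with equal-length rows the two averages coincide. A small sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf


class MeanSquaredError(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        # Mean of squared errors over every element
        return tf.reduce_mean(tf.square(y_pred - y_true))


y_true = tf.constant([[1.0], [2.0]])
y_pred = tf.constant([[2.0], [4.0]])   # errors 1 and 2, squared 1 and 4

custom_loss = MeanSquaredError()(y_true, y_pred)
builtin_loss = tf.keras.losses.MeanSquaredError()(y_true, y_pred)
print(float(custom_loss), float(builtin_loss))  # 2.5 2.5
```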

A custom evaluation metric subclasses tf.keras.metrics.Metric and overrides three methods: __init__, update_state, and result. The example below is a simple reimplementation of the SparseCategoricalAccuracy metric used earlier:

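Variable.assign_add(value): updates the variable in place by adding value, i.e. ref = ref + value. (The TF 1.x function form tf.assign_add(ref, value) is available in 2.x as tf.compat.v1.assign_add.)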

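class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        # Running totals kept as metric variables so they persist across batches
        self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer())
        self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer())

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is ignored in this simple version.
        # Flatten labels to shape (N,) and cast to int32, so tf.equal compares
        # element-wise instead of broadcasting (N, 1) labels against (N,) predictions.
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        values = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1, output_type=tf.int32)), tf.int32)
        # assign_add updates the variable in place; writing
        # self.total = self.total + tf.shape(y_true)[0] would replace it with a plain tensor.
        self.total.assign_add(tf.shape(y_true)[0])
        self.count.assign_add(tf.reduce_sum(values))

    def result(self):
        # Fraction of correct predictions seen so far
        return tf.cast(self.count, tf.float32) / tf.cast(self.total, tf.float32)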
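The sketch below (assuming TensorFlow 2.x) runs the metric on a tiny hand-made batch and compares it with the built-in tf.keras.metrics.SparseCategoricalAccuracy. The class is repeated so the snippet is self-contained; labels are flattened and cast to int32 so that column-shaped float labels, which Keras commonly passes, are compared element-wise.

```python
import tensorflow as tf


class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        self.total = self.add_weight(name="total", dtype="int32", initializer="zeros")
        self.count = self.add_weight(name="count", dtype="int32", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten (N, 1) labels to (N,) and cast to int32 before comparing.
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_hat = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
        values = tf.cast(tf.equal(y_true, y_hat), tf.int32)
        self.total.assign_add(tf.shape(y_true)[0])
        self.count.assign_add(tf.reduce_sum(values))

    def result(self):
        return tf.cast(self.count, tf.float32) / tf.cast(self.total, tf.float32)


labels = tf.constant([[0.0], [1.0], [2.0]])   # column-shaped float labels
probs = tf.constant([[0.8, 0.1, 0.1],         # argmax 0: correct
                     [0.1, 0.2, 0.7],         # argmax 2: wrong
                     [0.2, 0.1, 0.7]])        # argmax 2: correct

mine = SparseCategoricalAccuracy()
mine.update_state(labels, probs)
ref = tf.keras.metrics.SparseCategoricalAccuracy()
ref.update_state(labels, probs)
print(float(mine.result()), float(ref.result()))  # both about 0.667 (2 of 3 correct)
```

Like a built-in metric, an instance can be passed to model.compile(metrics=[...]); Keras then calls update_state on each batch and result() to report the value.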

Training a model with the custom layer and loss function

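model = LinearModel()
mse = MeanSquaredError()
# Configure training
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.05),     # optimizer
    loss=mse,   # the custom loss function
)
# 6 samples with 1 feature each, shaped (6, 1); the targets follow y = x + 4
x = tf.constant([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y = tf.constant([[5.0], [6.0], [7.0], [8.0], [9.0], [10.0]])
model.fit(x, y, epochs=100)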
Train on 6 samples
Epoch 1/100
6/6 [==============================] - 0s 18ms/sample - loss: 59.1667
Epoch 2/100
6/6 [==============================] - 0s 167us/sample - loss: 22.6748
Epoch 3/100
6/6 [==============================] - 0s 332us/sample - loss: 9.5530
Epoch 4/100
6/6 [==============================] - 0s 166us/sample - loss: 4.8023
Epoch 5/100
6/6 [==============================] - 0s 166us/sample - loss: 3.0510
Epoch 6/100
6/6 [==============================] - 0s 166us/sample - loss: 2.3757
Epoch 7/100
6/6 [==============================] - 0s 169us/sample - loss: 2.0875
Epoch 8/100
6/6 [==============================] - 0s 330us/sample - loss: 1.9396
Epoch 9/100
6/6 [==============================] - 0s 333us/sample - loss: 1.8435
Epoch 10/100
6/6 [==============================] - 0s 332us/sample - loss: 1.7676
Epoch 11/100
6/6 [==============================] - 0s 333us/sample - loss: 1.7003
Epoch 12/100
6/6 [==============================] - 0s 166us/sample - loss: 1.6376
Epoch 13/100
6/6 [==============================] - 0s 167us/sample - loss: 1.5780
Epoch 14/100
6/6 [==============================] - 0s 332us/sample - loss: 1.5208
Epoch 15/100
6/6 [==============================] - 0s 332us/sample - loss: 1.4657
Epoch 16/100
6/6 [==============================] - 0s 166us/sample - loss: 1.4127
Epoch 17/100
6/6 [==============================] - 0s 166us/sample - loss: 1.3616
Epoch 18/100
6/6 [==============================] - 0s 332us/sample - loss: 1.3124
Epoch 19/100
6/6 [==============================] - 0s 333us/sample - loss: 1.2649
Epoch 20/100
6/6 [==============================] - 0s 171us/sample - loss: 1.2192
Epoch 21/100
6/6 [==============================] - 0s 161us/sample - loss: 1.1751
Epoch 22/100
6/6 [==============================] - 0s 161us/sample - loss: 1.1326
Epoch 23/100
6/6 [==============================] - 0s 333us/sample - loss: 1.0916
Epoch 24/100
6/6 [==============================] - 0s 166us/sample - loss: 1.0522
Epoch 25/100
6/6 [==============================] - 0s 333us/sample - loss: 1.0141
Epoch 26/100
6/6 [==============================] - 0s 332us/sample - loss: 0.9774
Epoch 27/100
6/6 [==============================] - 0s 502us/sample - loss: 0.9421
Epoch 28/100
6/6 [==============================] - 0s 330us/sample - loss: 0.9080
Epoch 29/100
6/6 [==============================] - 0s 166us/sample - loss: 0.8752
Epoch 30/100
6/6 [==============================] - 0s 333us/sample - loss: 0.8435
Epoch 31/100
6/6 [==============================] - 0s 166us/sample - loss: 0.8130
Epoch 32/100
6/6 [==============================] - 0s 168us/sample - loss: 0.7836
Epoch 33/100
6/6 [==============================] - 0s 166us/sample - loss: 0.7553
Epoch 34/100
6/6 [==============================] - 0s 332us/sample - loss: 0.7280
Epoch 35/100
6/6 [==============================] - 0s 500us/sample - loss: 0.7017
Epoch 36/100
6/6 [==============================] - 0s 164us/sample - loss: 0.6763
Epoch 37/100
6/6 [==============================] - 0s 169us/sample - loss: 0.6518
Epoch 38/100
6/6 [==============================] - 0s 330us/sample - loss: 0.6283
Epoch 39/100
6/6 [==============================] - 0s 332us/sample - loss: 0.6055
Epoch 40/100
6/6 [==============================] - 0s 333us/sample - loss: 0.5836
Epoch 41/100
6/6 [==============================] - 0s 333us/sample - loss: 0.5625
Epoch 42/100
6/6 [==============================] - 0s 166us/sample - loss: 0.5422
Epoch 43/100
6/6 [==============================] - 0s 331us/sample - loss: 0.5226
Epoch 44/100
6/6 [==============================] - 0s 166us/sample - loss: 0.5037
Epoch 45/100
6/6 [==============================] - 0s 166us/sample - loss: 0.4855
Epoch 46/100
6/6 [==============================] - 0s 167us/sample - loss: 0.4679
Epoch 47/100
6/6 [==============================] - 0s 331us/sample - loss: 0.4510
Epoch 48/100
6/6 [==============================] - 0s 166us/sample - loss: 0.4347
Epoch 49/100
6/6 [==============================] - 0s 332us/sample - loss: 0.4190
Epoch 50/100
6/6 [==============================] - 0s 333us/sample - loss: 0.4038
Epoch 51/100
6/6 [==============================] - 0s 332us/sample - loss: 0.3892
Epoch 52/100
6/6 [==============================] - 0s 332us/sample - loss: 0.3751
Epoch 53/100
6/6 [==============================] - 0s 166us/sample - loss: 0.3616
Epoch 54/100
6/6 [==============================] - 0s 168us/sample - loss: 0.3485
Epoch 55/100
6/6 [==============================] - 0s 331us/sample - loss: 0.3359
Epoch 56/100
6/6 [==============================] - 0s 334us/sample - loss: 0.3238
Epoch 57/100
6/6 [==============================] - 0s 333us/sample - loss: 0.3120
Epoch 58/100
6/6 [==============================] - 0s 332us/sample - loss: 0.3008
Epoch 59/100
6/6 [==============================] - 0s 330us/sample - loss: 0.2899
Epoch 60/100
6/6 [==============================] - 0s 166us/sample - loss: 0.2794
Epoch 61/100
6/6 [==============================] - 0s 167us/sample - loss: 0.2693
Epoch 62/100
6/6 [==============================] - 0s 332us/sample - loss: 0.2596
Epoch 63/100
6/6 [==============================] - 0s 166us/sample - loss: 0.2502
Epoch 64/100
6/6 [==============================] - 0s 166us/sample - loss: 0.2411
Epoch 65/100
6/6 [==============================] - 0s 168us/sample - loss: 0.2324
Epoch 66/100
6/6 [==============================] - 0s 167us/sample - loss: 0.2240
Epoch 67/100
6/6 [==============================] - 0s 165us/sample - loss: 0.2159
Epoch 68/100
6/6 [==============================] - 0s 333us/sample - loss: 0.2081
Epoch 69/100
6/6 [==============================] - 0s 166us/sample - loss: 0.2006
Epoch 70/100
6/6 [==============================] - 0s 164us/sample - loss: 0.1933
Epoch 71/100
6/6 [==============================] - 0s 332us/sample - loss: 0.1863
Epoch 72/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1796
Epoch 73/100
6/6 [==============================] - 0s 333us/sample - loss: 0.1731
Epoch 74/100
6/6 [==============================] - 0s 167us/sample - loss: 0.1668
Epoch 75/100
6/6 [==============================] - 0s 333us/sample - loss: 0.1608
Epoch 76/100
6/6 [==============================] - 0s 333us/sample - loss: 0.1550
Epoch 77/100
6/6 [==============================] - 0s 169us/sample - loss: 0.1494
Epoch 78/100
6/6 [==============================] - 0s 333us/sample - loss: 0.1440
Epoch 79/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1388
Epoch 80/100
6/6 [==============================] - 0s 332us/sample - loss: 0.1338
Epoch 81/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1289
Epoch 82/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1243
Epoch 83/100
6/6 [==============================] - 0s 333us/sample - loss: 0.1198
Epoch 84/100
6/6 [==============================] - 0s 167us/sample - loss: 0.1154
Epoch 85/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1113
Epoch 86/100
6/6 [==============================] - 0s 166us/sample - loss: 0.1072
Epoch 87/100
6/6 [==============================] - 0s 498us/sample - loss: 0.1034
Epoch 88/100
6/6 [==============================] - 0s 166us/sample - loss: 0.0996
Epoch 89/100
6/6 [==============================] - 0s 169us/sample - loss: 0.0960
Epoch 90/100
6/6 [==============================] - 0s 163us/sample - loss: 0.0925
Epoch 91/100
6/6 [==============================] - 0s 333us/sample - loss: 0.0892
Epoch 92/100
6/6 [==============================] - 0s 166us/sample - loss: 0.0860
Epoch 93/100
6/6 [==============================] - 0s 169us/sample - loss: 0.0829
Epoch 94/100
6/6 [==============================] - 0s 330us/sample - loss: 0.0799
Epoch 95/100
6/6 [==============================] - 0s 334us/sample - loss: 0.0770
Epoch 96/100
6/6 [==============================] - 0s 332us/sample - loss: 0.0742
Epoch 97/100
6/6 [==============================] - 0s 333us/sample - loss: 0.0715
Epoch 98/100
6/6 [==============================] - 0s 166us/sample - loss: 0.0689
Epoch 99/100
6/6 [==============================] - 0s 332us/sample - loss: 0.0664
Epoch 100/100
6/6 [==============================] - 0s 166us/sample - loss: 0.0640


model.evaluate(x,y,verbose=2)
6/1 - 0s - loss: 0.0617
0.061718087643384933
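The training data satisfy y = x + 4 exactly, so after training the layer's w and b should move toward 1 and 4. The sketch below (assuming TensorFlow 2.x) repeats the example with verbose=0 so it runs on its own, then reads the learned parameters and predicts an unseen input; the exact numbers depend on the number of epochs and the learning rate.

```python
import tensorflow as tf


class LinearLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(name="w", shape=[input_shape[-1], self.units],
                                 initializer="zeros")
        self.b = self.add_weight(name="b", shape=[self.units], initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


class LinearModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer1 = LinearLayer(units=1)

    def call(self, inputs):
        return self.layer1(inputs)


class MeanSquaredError(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.square(y_pred - y_true))


x = tf.constant([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
y = x + 4.0  # same targets as in the text: 5.0 ... 10.0

model = LinearModel()
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.05),
              loss=MeanSquaredError())
model.fit(x, y, epochs=100, verbose=0)

loss = float(model.evaluate(x, y, verbose=0))
w = float(model.layer1.w.numpy()[0, 0])
b = float(model.layer1.b.numpy()[0])
pred = float(model(tf.constant([[7.0]])).numpy()[0, 0])  # true value would be 11
print(f"loss={loss:.4f}  w={w:.3f}  b={b:.3f}  f(7)={pred:.3f}")
```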
