tensorflow2中自定义层或网络

通过自定义网络可以将一些现有的网络和自己的网络串联起来,这样可以组建各种各样的网络。

1、keras.Sequential

它的用法见这篇博客

network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()

2、自定义层或网络

(1)、实现自定义网络层

在tensorflow 2中自定义网络层最好的方法就是继承tf.keras.layers.Layer类,并且根据实际情况重写如下几个类方法:

  • __init__: 初始化类,你可以在此配置一些网络层需要的参数,并且也可以在此实例化tf.keras提供的一些基础算子比如DepthwiseConv2D,方便在call方法中应用;
  • build:该方法可以获取输入的shape维度,方便动态构建某些需要的算子比如Pool或基于input shape构建权重;
  • call: 网络层进行前向推理的实现方法;

实现自定义层示例:

class MyDense(tf.keras.layers.Layer):
  def __init__(self, inp_dim, out_dim):
    super(MyDense, self).__init__()

    self.kernel = self.add_variable('w', [inp_dim, outp_dim])
    self.bias = self.add_variable('b', [outp_dim])

  def call(self, inputs, training = None):
    out = inputs @ self.kernel + self.bias
    return out

(2) 实现自定义模型Model

要实现自定义模型,我们的模型类需要先继承keras.layers.Model类,这样我们的模型类就能使用母类中的许多方法,例如:compile、fit、 evaluate等。

然后,在模型类中实现 __init__()call() 这两个方法。其原理和自定义层是一样的。

实现自定义模型Model示例:

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.fc1 = MyDense(28*28, 256)
    self.fc2 = MyDense(256, 128)
    self.fc3 = MyDense(128, 64)
    self.fc4 = MyDense(64, 32)
    self.fc5 = MyDense(32, 10)

  def call(self, inputs, training = None):
    x = self.fc1(inputs)
    x = tf.nn.relu(x)
    x = self.fc2(x)
    x = tf.nn.relu(x)
    x = self.fc3(x)
    x = tf.nn.relu(x)
    x = self.fc4(x)
    x = tf.nn.relu(x)
    x = self.fc5(x)
    return x

3、自定义层或网络实战(前向传播)

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras


def preprocess(x, y):
    """
    x is a simple image, not a batch
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()


class MyDense(layers.Layer):

    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()

        self.kernel = self.add_variable('w', [inp_dim, outp_dim])
        self.bias = self.add_variable('b', [outp_dim])

    def call(self, inputs, training=None):
        out = inputs @ self.kernel + self.bias

        return out


class MyModel(keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()

        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)

        return x


network = MyModel()

network.compile(optimizer=optimizers.Adam(lr=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )

network.fit(db, epochs=5, validation_data=ds_val,
            validation_freq=2)

network.evaluate(ds_val)

sample = next(iter(ds_val))
x = sample[0]
y = sample[1]  # one-hot
pred = network.predict(x)  # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)

print(pred)
print(y)

你可能感兴趣的:(TensorFlow2.×)