Overriding the get_config method in custom Keras layers

Context:

This applies when saving a model in HDF5 format.

Motivation:

  1. Without a get_config override, the model graph cannot be loaded in TensorBoard.
  2. model.save cannot be used to save the model.

The reason: the base class get_config actually refuses to run (it raises NotImplementedError) when the subclass's __init__ declares arguments of its own and get_config has not been overridden.
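The following simplified, pure-Python imitation makes that check concrete. It is loosely modeled on tf.keras 2.x. BaseLayer, NoConfigLinear and the identity-based "is this still the default method?" test are hypothetical stand-ins written for illustration. They are not the real Keras source or API, and the real code detects the default method differently.

```python
import inspect


class BaseLayer:
    """Hypothetical stand-in for keras.layers.Layer (illustration only)."""

    def __init__(self, name=None, trainable=True, dtype="float32", **kwargs):
        self.name = name
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        config = {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}
        # Named parameters of the concrete class's __init__ (skip self, *args, **kwargs)
        sig = inspect.signature(type(self).__init__)
        params = [
            p.name
            for p in sig.parameters.values()
            if p.name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        ]
        extra = [name for name in params if name not in config]
        # Refuse when the subclass added arguments but still uses this default method
        if extra and type(self).get_config is BaseLayer.get_config:
            raise NotImplementedError(
                f"Layer {type(self).__name__} has arguments {extra} in __init__ "
                "and therefore must override get_config."
            )
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the config back to __init__
        return cls(**config)


class NoConfigLinear(BaseLayer):
    # Hypothetical: adds a "units" argument but does NOT override get_config
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units


class Linear(NoConfigLinear):
    # Same constructor, plus the get_config override
    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


try:
    NoConfigLinear(64).get_config()
except NotImplementedError as err:
    print("refused:", err)

config = Linear(64, name="dense_like").get_config()
print(config)  # {'name': 'dense_like', 'trainable': True, 'dtype': 'float32', 'units': 64}
restored = Linear.from_config(config)
print(restored.units)  # 64
```

Once "units" appears in the returned dictionary, from_config can pass it straight back to __init__. That round trip is the one the real Keras example relies on.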

Approach:

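In the custom layer, override get_config and return the constructor arguments as a dictionary, merged into the config produced by the base class.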

Example:

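import tensorflow as tf
from tensorflow import keras


class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        # Forward name, trainable, dtype, etc. to the base Layer
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Explicit weight names are a common precaution so that the
        # weight entries stay distinct when saving in HDF5 format
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            name="b", shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add
        # every argument __init__ needs to rebuild this layer
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)  # the base Layer keys plus "units": 64
# By default from_config calls Linear(**config), so __init__ must accept every key
new_layer = Linear.from_config(config)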
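The sketch below covers the HDF5 case this post starts from with a full save/load round trip. It assumes TensorFlow 2.x with h5py installed. The temporary file path, the input shape (4,) and units=8 are arbitrary choices for illustration. The class is repeated so the block runs on its own. The key detail is custom_objects={"Linear": Linear}, which tells load_model how to map the saved class name back to the Python class.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            name="w", shape=(input_shape[-1], self.units),
            initializer="random_normal", trainable=True,
        )
        self.b = self.add_weight(
            name="b", shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


# A tiny functional model that uses the custom layer
inputs = keras.Input(shape=(4,))
outputs = Linear(8)(inputs)
model = keras.Model(inputs, outputs)

# The ".h5" suffix selects the HDF5 format
path = os.path.join(tempfile.mkdtemp(), "linear_model.h5")
model.save(path)

# Without custom_objects, load_model would not know what "Linear" refers to
restored = keras.models.load_model(path, custom_objects={"Linear": Linear})

x = np.random.rand(2, 4).astype("float32")
original_out = model.predict(x, verbose=0)
restored_out = restored.predict(x, verbose=0)
```

When loading, Keras rebuilds each layer from its saved config (via from_config) and then restores the weights. That is why get_config must return every constructor argument.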

