Tensorflow2.0 定义模型的三种方法

1、API

通过直接使用 tf.keras.Sequential() 函数可以轻松地构建网络,如:

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = False
model = tf.keras.Sequential([
  simplified_mobile,
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(28, activation='softmax')
])

但是,通过 API 定义的方法并不容易自定义复杂的网络。

2、通过函数定义

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = False

def MobileNetV2 (classes):
    img_input = tf.keras.layers.Input(shape=(224, 224, 3))

    x = mobile(img_input)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(classes, activation='softmax')(x)

    model = tf.keras.Model(img_input, x)

    return model

使用函数定义网络时要注意以下几点:

  • 1、一个网络中往往包含了多个自定义网络函数,如卷积-批归一化-激活函数等,但在最后构建网络的函数的开头必须定义输入该网络的形状,而不是直接在函数名后面定义一个输入。当然,对前面的函数来说是可以直接多定义一个输入的;
  • 2、在函数结尾必须有:model = tf.keras.Model(img_input, x)。

3、通过类定义

mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) 
mobile.trainable = False

class MobileNetV2(tf.keras.Model):
    def __init__(self, classes):
        super().__init__()
       
        self.mob = mobile
                
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.dense = tf.keras.layers.Dense(classes, activation='softmax')
        
    def call(self, inputs):

        x = self.mob(inputs)

        x = self.dropout(x)
        x = self.gap(x)
        x = self.dense(x)
        
        return x

你可能感兴趣的:(Tensorflow,2.0,深度学习,tensorflow,深度学习,神经网络)