Deep Learning Notes -- Creating and Loading a VGG Model with Keras

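As its name suggests, VGG16 has 16 layers. Reading the source code on GitHub shows that these are 13 convolutional (Conv2D) layers and 3 fully connected (Dense) layers. The 5 max-pooling layers between the convolutional blocks have no weights and are not included in the count of 16.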
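The layer count can be checked mechanically. A minimal sketch, assuming the layer sequence of the code in this article is written as a compact list (the list format and the names `VGG16_CFG`, `DENSE_LAYERS` and `count_layers` are hypothetical, introduced only for illustration):

```python
# Hypothetical compact description of the VGG16 layer sequence:
# integers are Conv2D filter counts, "M" marks a 2x2 max-pooling layer.
VGG16_CFG = [64, 64, "M",
             128, 128, "M",
             256, 256, 256, "M",
             512, 512, 512, "M",
             512, 512, 512, "M"]
DENSE_LAYERS = [4096, 4096, "num_classes"]  # fc1, fc2, fc3


def count_layers(cfg, dense):
    """Return (#conv, #pool, #dense, #weight layers) for a VGG-style config."""
    n_conv = sum(1 for v in cfg if v != "M")
    n_pool = sum(1 for v in cfg if v == "M")
    n_dense = len(dense)
    # Pooling layers carry no weights, so only conv + dense layers make up the "16".
    return n_conv, n_pool, n_dense, n_conv + n_dense


print(count_layers(VGG16_CFG, DENSE_LAYERS))  # (13, 5, 3, 16)
```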
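Keras offers two ways to build a model instance: the functional API, which starts from an `Input` layer and wraps the resulting layer graph in a `Model`, and the `Sequential` class, which simply stacks layers in order. The two look different in code; this article uses the functional API built from `Input`, and assumes an input image size of (50, 50, 3). The code makes the parameters of every VGG layer easy to see.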
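With such a small 50×50 input, five poolings shrink the feature map very quickly. A short sketch that traces the side length block by block, assuming (as in this article's code) that every `Conv2D` uses `padding='same'` with stride 1 and every block ends with `MaxPooling2D((2, 2), strides=(2, 2))` using Keras's default `'valid'` padding; the function name is hypothetical:

```python
def trace_spatial_sizes(size, n_blocks=5):
    """Side length of the feature map after each VGG block (square input).

    Conv2D with padding='same' and stride 1 keeps the size unchanged, so only
    the 2x2, stride-2 'valid' pooling at the end of each block changes it:
    out = floor((size - 2) / 2) + 1.
    """
    sizes = [size]
    for _ in range(n_blocks):
        size = (size - 2) // 2 + 1  # output size of 2x2, stride-2 'valid' pooling
        sizes.append(size)
    return sizes


print(trace_spatial_sizes(50))   # [50, 25, 12, 6, 3, 1]
print(trace_spatial_sizes(224))  # [224, 112, 56, 28, 14, 7]
```

For a 50×50 input the last block leaves a 1×1×512 map, so `Flatten` produces only 512 values; for the 224×224 input of the original VGG16 it leaves 7×7×512.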

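To load pretrained parameters, the .h5 weights file must first be downloaded locally. This article mainly walks through the steps for building the VGG16 model; the code largely follows the Keras VGG16 source at: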

https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/applications/vgg16.py
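`load_weights` copies the arrays stored in the .h5 file into the model's layers, so their shapes have to match. The official Keras VGG16 ImageNet weights belong to a model with 224×224 inputs and 1000 classes. The hedged sketch below compares kernel shapes of the 50×50, 20-class model with that configuration; the helper `vgg16_weight_shapes` is hypothetical, it ignores bias vectors, and it assumes both models use this article's layer names:

```python
def vgg16_weight_shapes(input_size, num_classes):
    """Kernel shape of every weight layer in this article's VGG16 (hypothetical helper)."""
    shapes = {}
    c_in = 3  # RGB input channels
    blocks = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]  # (filters, conv layers)
    for b, (filters, n_convs) in enumerate(blocks, start=1):
        for i in range(1, n_convs + 1):
            # Conv kernels are (3, 3, C_in, C_out) regardless of the input size
            shapes[f"block{b}_conv{i}"] = (3, 3, c_in, filters)
            c_in = filters
    # Five 2x2/stride-2 poolings divide the side length by 32 (rounding down)
    flat = c_in * (input_size // 32) ** 2
    shapes["fc1"] = (flat, 4096)
    shapes["fc2"] = (4096, 4096)
    shapes["fc3"] = (4096, num_classes)
    return shapes


ours = vgg16_weight_shapes(50, 20)
imagenet = vgg16_weight_shapes(224, 1000)
mismatched = [name for name in ours if ours[name] != imagenet[name]]
print(ours["fc1"], imagenet["fc1"])  # (512, 4096) (25088, 4096)
print(mismatched)                    # ['fc1', 'fc3']
```

Only `fc1` and `fc3` differ: the convolutional kernels do not depend on the input size, so pretrained convolutional weights can in principle be reused, while the classification block must match the new input size and class count.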

# Standalone Keras imports; with TensorFlow 2.x the same classes are available
# from tensorflow.keras.models and tensorflow.keras.layers.
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense


def VGG16(num_classes, importModel=None):
    """Build VGG16 for 50x50 RGB inputs.

    num_classes: number of outputs of the final softmax layer.
    importModel: optional path to an .h5 weights file whose layers match this
                 architecture (same input size and num_classes).
    """
    image_input = Input(shape=(50, 50, 3))
    # Block 1: two 3x3 conv layers with 64 filters, then 2x2 max pooling (halves H and W)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(image_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    # Block 2: two conv layers with 128 filters
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
    # Block 3: three conv layers with 256 filters
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
    # Block 4: three conv layers with 512 filters
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
    # Block 5: three conv layers with 512 filters
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
    # Classification block: flatten, two 4096-unit ReLU layers, softmax over num_classes
    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dense(num_classes, activation='softmax', name='fc3')(x)
    model = Model(image_input, x, name='vgg16')
    if importModel:
        # Load the saved weights into the model just built; the layer shapes
        # stored in the .h5 file must match this architecture.
        model.load_weights(importModel)
    return model


if __name__ == "__main__":
    newModel = VGG16(20)
    newModel.summary()  # summary() prints the table itself and returns None
    # importModel = VGG16(20, 'vgg16_weights.h5')
    # importModel.summary()

The output is:

Using TensorFlow backend.
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 50, 50, 3)         0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 50, 50, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 50, 50, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 25, 25, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 25, 25, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 25, 25, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 12, 12, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 12, 12, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 12, 12, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 12, 12, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 6, 6, 256)         0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 6, 6, 512)         1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 6, 6, 512)         2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 6, 6, 512)         2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 3, 3, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 3, 3, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 3, 3, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 3, 3, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 1, 1, 512)         0
_________________________________________________________________
flatten (Flatten)            (None, 512)               0
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              2101248
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312
_________________________________________________________________
fc3 (Dense)                  (None, 20)                81940
=================================================================
Total params: 33,679,188
Trainable params: 33,679,188
Non-trainable params: 0
_________________________________________________________________
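Every `Param #` value in the summary follows from two formulas: a `Conv2D` layer has k·k·C_in weights plus one bias per filter, and a `Dense` layer has n_in weights plus one bias per unit. A small sketch that recomputes the totals, assuming this article's layer configuration (all function names are hypothetical):

```python
def conv_params(k, c_in, c_out):
    """Conv2D: k*k*c_in weights per filter plus one bias per filter."""
    return (k * k * c_in + 1) * c_out


def dense_params(n_in, n_out):
    """Dense: n_in weights per unit plus one bias per unit."""
    return (n_in + 1) * n_out


def vgg16_param_count(input_size=50, num_classes=20):
    total = 0
    c_in = 3  # RGB input
    for filters, n_convs in [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]:
        for _ in range(n_convs):
            total += conv_params(3, c_in, filters)
            c_in = filters
    flat = c_in * (input_size // 32) ** 2  # Flatten length after five 2x2 poolings
    total += dense_params(flat, 4096) + dense_params(4096, 4096) + dense_params(4096, num_classes)
    return total


print(conv_params(3, 3, 64))     # 1792, block1_conv1
print(dense_params(4096, 20))    # 81940, fc3
print(vgg16_param_count())       # 33679188, "Total params"
```

The same function with `input_size=224` and `num_classes=1000` gives 138,357,544, the parameter count of the original VGG16. In both cases most parameters sit in the fully connected layers: here `fc2` alone accounts for 16,781,312 of the 33,679,188.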
