Keras实现VGG

 

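VGG 的网络结构如图所示。VGG16 和 VGG19 都由 5 个卷积阶段加 3 个全连接层组成，两者每个阶段的卷积核数目相同（依次为 64、128、256、512、512），差异在于后 3 个阶段中卷积层的个数：VGG16 每个阶段 3 层，VGG19 每个阶段 4 层。注意下面代码中的 vgg-16 配置使用的卷积核数目（32、64、128、256、512）比原论文更少。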

代码如下。这里使用 MNIST 数据集（28×28 的单通道图像）训练。注意每次 2×2、步长为 2 的最大池化（MaxPooling）之后，特征图的分辨率会减半并向下取整：28→14→7→3→1。所以 4 次池化之后分辨率就只剩 1×1，第 5 次池化就会出错。解决办法有两种：在每次池化之后做一次零填充（ZeroPadding），或者直接去掉第 5 次池化。

# from keras.models import
from keras.layers import *
from keras.models import Input, load_model, Sequential
from keras import Model
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.losses import categorical_crossentropy
import keras.optimizers
import numpy as np


def vgg(input_shape, num_cls, filters_num, conv_nums):
    """Build and compile a VGG-style classifier.

    The network has one convolutional stage per entry of ``conv_nums``.
    Stage ``i`` stacks ``conv_nums[i]`` 3x3 'same'-padded convolutions with
    ``filters_num[i]`` filters, each followed by ReLU. The stage then applies
    a 2x2 / stride-2 max-pool and a 1-pixel zero padding. The head is two
    4096-unit ReLU dense layers followed by a softmax layer with ``num_cls``
    units.

    Args:
        input_shape: Shape of a single input sample, e.g. ``(28, 28, 1)``.
        num_cls: Number of output classes.
        filters_num: Number of conv filters for each stage.
        conv_nums: Number of conv layers for each stage.

    Returns:
        A ``Model`` named ``'vgg'``, compiled with the ``'sgd'`` optimizer,
        categorical cross-entropy loss, and the ``'acc'`` metric.

    Raises:
        ValueError: If ``filters_num`` and ``conv_nums`` differ in length.
    """
    if len(filters_num) != len(conv_nums):
        raise ValueError(
            'filters_num and conv_nums must have the same length, '
            'got {0} and {1}'.format(len(filters_num), len(conv_nums)))

    inputs = Input(shape=input_shape)
    x = inputs
    stages = zip(filters_num, conv_nums)
    for stage, (filters, n_convs) in enumerate(stages, start=1):
        for conv in range(1, n_convs + 1):
            # ReLU after every conv: without a non-linearity the stacked
            # convolutions collapse into a single linear transform.
            x = Conv2D(filters=filters, kernel_size=3, padding='same',
                       activation='relu',
                       name='stage{0}_conv{1}'.format(stage, conv))(x)
        x = MaxPool2D((2, 2), strides=2, name='maxpool_' + str(stage))(x)
        # Each 2x2 pool halves (and floors) the spatial size. On small inputs
        # such as 28x28 the 5th pool would otherwise see a 1x1 map and fail,
        # so pad one pixel on every side after each pool.
        x = ZeroPadding2D((1, 1))(x)
    x = Flatten(name='flatten')(x)
    x = Dense(units=4096, activation='relu', name='dense4096_1')(x)
    x = Dense(units=4096, activation='relu', name='dense4096_2')(x)
    # The name 'dense1000' presumably refers to the 1000-class ImageNet head.
    # It is kept unchanged, but this layer actually has num_cls units.
    x = Dense(units=num_cls, name='dense1000', activation='softmax')(x)
    model = Model(inputs=inputs, outputs=x, name='vgg')
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['acc'])
    return model


def _load_mnist(data_path, num_classes):
    """Load MNIST as float32 (N, 28, 28, 1) images in [0, 1] with one-hot labels.

    Args:
        data_path: Path to a local ``mnist.npz``, or ``None`` to use
            ``mnist.load_data()``.
        num_classes: Number of classes for the one-hot label encoding.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    if data_path is None:
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    else:
        # allow_pickle stays at its default (False). The standard mnist.npz
        # only holds plain arrays, and unpickling an untrusted file can
        # execute arbitrary code.
        with np.load(data_path) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']

    # Add the single channel axis expected by Conv2D and scale to [0, 1].
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255.
    x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255.
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)
    return (x_train, y_train), (x_test, y_test)


def train(net_name, data_path=None, batch_size=16, epochs=1):
    """Train a VGG variant on MNIST, save it, and print test-set metrics.

    The trained model is saved to ``'<net_name>-mnist.h5'`` in the current
    directory. The output of ``model.evaluate`` on the test set is printed.

    Args:
        net_name: ``'vgg-19'`` or ``'vgg-16'``.
        data_path: Optional path to a local ``mnist.npz``. When ``None``,
            ``mnist.load_data()`` is used instead.
        batch_size: Mini-batch size for training.
        epochs: Number of training epochs.

    Raises:
        ValueError: If ``net_name`` is not one of the known configurations.
    """
    # net_name -> (filters per stage, conv layers per stage)
    configs = {
        'vgg-19': ([64, 128, 256, 512, 512], [2, 2, 4, 4, 4]),
        # NOTE(review): narrower than the paper's VGG-16 filters
        # (64, 128, 256, 512, 512); presumably chosen to train faster on MNIST.
        'vgg-16': ([32, 64, 128, 256, 512], [2, 2, 3, 3, 3]),
    }
    # Fail fast on a typo instead of silently training a different network.
    if net_name not in configs:
        raise ValueError('unknown net_name {0!r}; expected one of {1}'.format(
            net_name, sorted(configs)))
    filters_num, conv_nums = configs[net_name]

    num_classes = 10
    (x_train, y_train), (x_test, y_test) = _load_mnist(data_path, num_classes)

    vgg_model = vgg(input_shape=(28, 28, 1), num_cls=num_classes, filters_num=filters_num,
                    conv_nums=conv_nums)
    vgg_model.summary()
    vgg_model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=0.1)
    vgg_model.save('{0}-mnist.h5'.format(net_name))
    eval_res = vgg_model.evaluate(x_test, y_test)
    print(eval_res)


if __name__ == '__main__':
    # Script entry point: train the 'vgg-16' configuration on MNIST.
    train(net_name='vgg-16')

 

你可能感兴趣的:(深度学习,Keras)