物体分类(二)VGGNet

前言

VGG-Net是由牛津大学VGG(Visual Geometry Group)提出,是2014年ImageNet竞赛定位任务的第一名和分类任务的第二名的中的基础网络。VGG可以看成是加深版本的AlexNet,都是Conv layer + Pooling layer + FC layer,它主要的贡献是展示出网络的深度(depth)是算法优良性能的关键部分,并且小卷积核表现出了更好的效果。虽然现在ResNet等网络表现出了更好的效果,但是VGG仍然作为许多经典网络特征提取的核心。

论文地址:https://arxiv.org/abs/1409.1556

网络结构

如图1,是一张常见的VGG结构图,输入大小为224*224的RGB图像,预处理(preprocession)是计算出三个通道的平均值,在每个像素上减去平均值(处理后迭代更少,更快收敛)。

物体分类(二)VGGNet_第1张图片 图1 VGG结构图

VGG提出了相对AlexNet更深的网络模型,并且通过实验发现网络越深性能越好(在一定范围内)。VGG使用了更小的卷积核(3x3),使得参数更少,并且非线性表达能力更强,stride为1,所有卷积层都使用ReLU作为激活函数;同时不单单的使用卷积层,而是组合成了“卷积组”,即一个卷积组包括2~4个3x3卷积层(a stack of 3x3 conv)(表1);有的层也有1x1卷积层,加大了对非线性的拟合能力,因此网络更深。网络使用2x2的max pooling,VGG的卷积层都是same的卷积,即卷积过后的输出图像的尺寸与输入是一致的,它的下采样完全是由max pooling来实现。另外VGGNet卷积层有一个显著的特点:特征图的空间分辨率单调递减,特征图的通道数单调递增,这是为了更好地将HxWx3的图像转换为1x1xC的输出,之后的GoogLeNet与Resnet都是如此。另外,VGG还删除了Alexnet中的LRN,实验表明LRN并没有提升网络的性能,却导致更多的内存消耗和计算时间。另外表1后面4个VGG训练时参数都是通过pre-trained 网络A进行初始赋值,以便加快训练效率。图较为流行的是VGG-16和VGG-19。

物体分类(二)VGGNet_第2张图片 表1 VGG模型

实现

下面是以前用keras实现的VGG16网络,最后一层改为了四分类问题。训练数据放在train目录下即可。另外,keras中已经封装了VGG实现,所以也可以直接import VGG16来调用网络。

# -*- coding: utf-8 -*-

from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Dense, Dropout, Flatten
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras import optimizers

def vgg16():
    model = Sequential()
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(224, 224, 3), padding='same', activation='relu'))
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2))
    model.add(Flatten())
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(4, activation='softmax'))

    return model

train_datagen = ImageDataGenerator(
        rescale=1./255,
        zoom_range=0.5,
        width_shift_range=0.2,
        height_shift_range=0.2)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory('./data/train',target_size=(224,224),shuffle=True,
                                                    batch_size=64,class_mode="categorical")
validation_generator = test_datagen.flow_from_directory('./data/validation',target_size=(224,224),batch_size=64,class_mode="categorical")

opt = optimizers.Adam(lr=0.0001)
model = vgg16()
model.compile(optimizer=opt, loss='categorical_crossentropy',metrics=['accuracy'])

checkpointer = ModelCheckpoint(filepath="./result/checkpoint-{epoch:05d}e-val_acc_{val_acc:.2f}.hdf5",
                               verbose=1, save_best_only=True, save_weights_only=True,period=2)

model.fit_generator(
        train_generator,
        steps_per_epoch=100,
        epochs=10000,
        verbose=1,
        validation_data=validation_generator,
        validation_steps=10,
        callbacks=[checkpointer])

 

参考资料

http://cs231n.github.io/convolutional-networks/

你可能感兴趣的:(深度学习)