Transfer Learning with a Pretrained VGG16 in Keras: 102flowers Classification [Training & Prediction in 100 Lines of Code]

Contents

  • I. Introduction
  • II. Training Code
  • III. Training Results
  • IV. Prediction Code
  • V. Reference Project





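I. Introduction

  1. Oxford 102flowers dataset: an image dataset released by the University of Oxford in 2008. It covers 102 flower categories commonly found in the United Kingdom, with 40 to 258 images per category.
    Paper: http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.pdf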




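II. Training Code

  1. The original dataset ships as raw images, a label file, and a split definition, so a script has to be run to sort the images into training, validation, and test folders. For convenience, an already-split copy is shared on Baidu Netdisk:
    Link: https://pan.baidu.com/s/1zEL-oji1Y9ZOMXMSXTSvgA
    Extraction code: isoj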
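
The training code that follows reads images with flow_from_directory, which expects one subdirectory per class under each split directory. As a rough sketch of what the splitting step amounts to, the hypothetical helper below copies images into split/label folders. The function name, its arguments, and the dummy file names are assumptions for illustration; reading the labels and split ids from the original dataset files is not shown.

```python
import os
import shutil
import tempfile


def sort_into_class_dirs(image_paths, labels, dest_root, split):
    """Copy each image to dest_root/<split>/<label>/ (hypothetical helper)."""
    for path, label in zip(image_paths, labels):
        class_dir = os.path.join(dest_root, split, str(label))
        os.makedirs(class_dir, exist_ok=True)
        shutil.copy(path, class_dir)


def demo():
    """Create two dummy files, sort them, and return the relative paths created."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, 'jpg')
        os.makedirs(src)
        paths = []
        for name in ('image_00001.jpg', 'image_00002.jpg'):
            p = os.path.join(src, name)
            with open(p, 'wb') as handle:
                handle.write(b'')  # empty placeholder, not a real image
            paths.append(p)
        dest = os.path.join(tmp, '102flowers')
        sort_into_class_dirs(paths, labels=[0, 1], dest_root=dest, split='train')
        created = []
        for root, _, files in os.walk(dest):
            for file_name in files:
                rel = os.path.relpath(os.path.join(root, file_name), dest)
                created.append(rel.replace(os.sep, '/'))
        return sorted(created)


if __name__ == '__main__':
    print(demo())
```

With a layout like this, the subdirectory names under train become the class names used by the training script.
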
import os
import glob
import math
import numpy as np
from keras import optimizers
from keras import applications
from keras.models import Model
from keras.layers import Flatten, Dense, Dropout, Input
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Dataset and hyperparameters
train_dir = 'data/102flowers/train'  # training set
validation_dir = 'data/102flowers/valid'  # validation set
nb_epoch = 50  # number of epochs (the original project defaults to 1000)
batch_size = 32  # batch size
img_size = (224, 224)  # input image size
freeze_layers_number = 0  # number of layers to freeze (not used below; the loop freezes every VGG16 layer)

classes = sorted([o for o in os.listdir(train_dir)])  # class names are the subdirectory names
nb_train_samples = len(glob.glob(train_dir + '/*/*.*'))  # number of training images
nb_validation_samples = len(glob.glob(validation_dir + '/*/*.*'))  # number of validation images

# Define the model
# Pretrained VGG16 convolutional base; include_top=False leaves out the original ImageNet classifier
base_model = applications.VGG16(weights='imagenet', include_top=False,
                                input_tensor=Input(shape=img_size + (3,)))

for layer in base_model.layers:  # freeze every pretrained layer so only the new top is trained
    print(layer.trainable)
    layer.trainable = False

x = base_model.output  # custom classifier on top of the pretrained base
x = Flatten()(x)  # flatten the feature map
x = Dense(4096, activation='elu', name='fc1')(x)  # fully connected layer with ELU activation
x = Dropout(0.6)(x)  # dropout with rate 0.6
x = Dense(4096, activation='elu', name='fc2')(x)
x = Dropout(0.6)(x)
predictions = Dense(len(classes), activation='softmax', name='predictions')(x)  # output layer, one unit per class

model = Model(inputs=base_model.input, outputs=predictions)  # new network = pretrained base + custom classifier

model.compile(loss='categorical_crossentropy', optimizer=optimizers.Adam(lr=1e-5), metrics=['accuracy'])
print(model.summary())

# Training augmentation: random rotation of up to 30 degrees, shear intensity 0.2,
# random zoom in [0.8, 1.2], and random horizontal flips
train_datagen = ImageDataGenerator(rotation_range=30., shear_range=0.2, zoom_range=0.2,
                                   horizontal_flip=True)
# Intended to subtract the ImageNet BGR channel means. ImageDataGenerator only applies `mean`
# when featurewise_center=True, which is not set here, so this assignment appears to have no effect.
train_datagen.mean = np.array([103.939, 116.779, 123.68], dtype=np.float32).reshape((3, 1, 1))
train_data = train_datagen.flow_from_directory(train_dir, target_size=img_size, classes=classes,
                                               batch_size=batch_size)
validation_datagen = ImageDataGenerator()  # validation data, no augmentation needed
validation_datagen.mean = np.array([103.939, 116.779, 123.68], dtype=np.float32).reshape((3, 1, 1))
validation_data = validation_datagen.flow_from_directory(validation_dir, target_size=img_size,
                                                         classes=classes, batch_size=batch_size)


# Train & save
def get_class_weight(d):
    '''
    calculate the weight of each class
    :param d: dir path
    :return: a dict
    '''
    white_list_formats = {'png', 'jpg', 'jpeg', 'bmp'}
    class_number = dict()
    dirs = sorted([o for o in os.listdir(d) if os.path.isdir(os.path.join(d, o))])
    k = 0
    for class_name in dirs:
        class_number[k] = 0
        iglob_iter = glob.iglob(os.path.join(d, class_name, '*.*'))
        for i in iglob_iter:
            _, ext = os.path.splitext(i)
            if ext[1:] in white_list_formats:
                class_number[k] += 1
        k += 1
    total = np.sum(list(class_number.values()))
    max_samples = np.max(list(class_number.values()))
    mu = 1. / (total / float(max_samples))
    keys = class_number.keys()
    class_weight = dict()
    for key in keys:
        score = math.log(mu * total / float(class_number[key]))
        class_weight[key] = score if score > 1. else 1.

    return class_weight


class_weight = get_class_weight(train_dir)  # per-class loss weights; classes with fewer images get larger weights

# Stop training if val_loss has not improved for 30 consecutive epochs
early_stopping = EarlyStopping(verbose=1, patience=30, monitor='val_loss')
# Save the model to disk each time val_loss reaches a new best
model_checkpoint = ModelCheckpoint(filepath='102flowersmodel.h5', verbose=1, save_best_only=True, monitor='val_loss')
callbacks = [early_stopping, model_checkpoint]

# The step counts are passed as fractions (samples / batch size), so the progress bar can overshoot its target
model.fit_generator(train_data, steps_per_epoch=nb_train_samples / float(batch_size), epochs=nb_epoch,
                    validation_data=validation_data, validation_steps=nb_validation_samples / float(batch_size),
                    callbacks=callbacks, class_weight=class_weight)

print('Training is finished!')
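
The weighting in get_class_weight reduces to a simple rule: since mu equals max_samples / total, the score for a class with n images is log(max_samples / n), floored at 1. The sketch below restates that rule on made-up counts. The counts 258, 100, and 40 are illustrative values taken from the 40 to 258 range quoted for the dataset, not the real per-class training counts, and the function name is an assumption.

```python
import math


def class_weights_from_counts(counts):
    """Same rule as get_class_weight, applied to a list of per-class image counts."""
    max_samples = max(counts)
    weights = {}
    for k, n in enumerate(counts):
        score = math.log(max_samples / float(n))  # equals log(mu * total / n)
        weights[k] = score if score > 1. else 1.
    return weights


# Illustrative counts only
weights = class_weights_from_counts([258, 100, 40])
for k, w in weights.items():
    print(k, round(w, 3))
```

Only classes with fewer than max_samples / e images (about 37% of the largest class) end up with a weight above 1; every other class keeps weight 1.
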




III. Training Results

Total params: 134,678,438
Trainable params: 119,963,750
Non-trainable params: 14,714,688
_________________________________________________________________
None
Found 6149 images belonging to 102 classes.
Found 1020 images belonging to 102 classes.
Epoch 1/50
193/192 [==============================] - 139s 721ms/step - loss: 21.1938 - acc: 0.0128 - val_loss: 14.8290 - val_acc: 0.0333
Epoch 2/50
193/192 [==============================] - 126s 652ms/step - loss: 20.5845 - acc: 0.0427 - val_loss: 14.5422 - val_acc: 0.0686

......

Epoch 48/50
193/192 [==============================] - 129s 667ms/step - loss: 5.7738 - acc: 0.6532 - val_loss: 3.6334 - val_acc: 0.6716
Epoch 49/50
193/192 [==============================] - 128s 665ms/step - loss: 5.8061 - acc: 0.6486 - val_loss: 3.4900 - val_acc: 0.6843
Epoch 50/50
193/192 [==============================] - 125s 646ms/step - loss: 5.6267 - acc: 0.6550 - val_loss: 3.3076 - val_acc: 0.6990
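
The parameter counts and the 193/192 step counter in the log above can be checked by hand. The sketch below assumes the VGG16 convolutional base ends in a 7 x 7 x 512 feature map for 224 x 224 inputs, so Flatten yields 25088 values; the helper name is made up for illustration.

```python
import math


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


flat = 7 * 7 * 512                  # Flatten output size for a 224x224 input
fc1 = dense_params(flat, 4096)      # first custom dense layer
fc2 = dense_params(4096, 4096)      # second custom dense layer
out = dense_params(4096, 102)       # softmax layer over 102 classes
trainable = fc1 + fc2 + out
frozen = 14714688                   # non-trainable count reported in the log
print(trainable, trainable + frozen)

steps = 6149 / float(32)            # training images / batch size
print(steps, math.ceil(steps))
```

The three dense layers account for all 119,963,750 trainable parameters. Covering 6149 images takes 193 batches (192 full batches plus one of 5 images), while the requested step count is 192.15625, which would explain the 193/192 display.
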

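Due to limited time, training ran for only 50 epochs (the original project defaults to 1000). It took about 150 minutes and reached a validation accuracy of 69.90%.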
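
A caveat when reading these numbers: the training script assigns an ImageNet mean to both generators, but as far as the Keras implementation goes, ImageDataGenerator only subtracts its mean attribute when featurewise_center=True, so the model was most likely trained on raw RGB pixel values. Keras ships keras.applications.vgg16.preprocess_input for the conversion that the ImageNet weights of VGG16 expect, and it can be passed as preprocessing_function to ImageDataGenerator. The NumPy sketch below shows what that conversion does for channels-last images; the function name is made up for illustration, and the same step would then also be needed at prediction time.

```python
import numpy as np

# ImageNet channel means in BGR order, the same values used in the training code
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def vgg16_style_preprocess(batch_rgb):
    """Convert a channels-last RGB batch (N, H, W, 3) to zero-centered BGR."""
    batch = np.asarray(batch_rgb, dtype=np.float32)
    batch_bgr = batch[..., ::-1]          # reverse the channel axis: RGB -> BGR
    return batch_bgr - IMAGENET_BGR_MEAN  # broadcasts over the last axis


# A single 1x1 "image" whose only pixel is (R, G, B) = (255, 0, 0)
pixel = np.array([[[[255., 0., 0.]]]])
out = vgg16_style_preprocess(pixel)
print(out.shape)
```

Whether switching to this preprocessing improves the accuracy reported here was not tested in the original run.
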

Training leaves the saved model file on disk, written by ModelCheckpoint.



IV. Prediction Code

import time
import numpy as np
from keras.models import load_model
from keras.preprocessing import image

# Load the model
start = time.perf_counter()  # time.clock() was removed in Python 3.8
model = load_model('102flowersmodel.h5')  # file written by ModelCheckpoint during training
print('Warming up took {}s'.format(time.perf_counter() - start))
# Output from the original run:
# Warming up took 7.572704s

# Image preprocessing
path = 'data/sorted/test/0/image_06736.jpg'
img_height, img_width = 224, 224
x = image.load_img(path=path, target_size=(img_height, img_width))
x = image.img_to_array(x)
x = x[None]  # add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)

# Predict
start = time.perf_counter()
y = model.predict(x)
print('Prediction took {}s'.format(time.perf_counter() - start))

# Top-5 class indices by confidence
for i in np.argsort(y[0])[::-1][:5]:
    print('{}:{:.2f}%'.format(i, y[0][i] * 100))
# Output from the original run:
# Prediction took 1.1909269999999985s
# 0:97.38%
# 86:1.80%
# 83:0.41%
# 68:0.35%
# 67:0.04%
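
The indices printed above are positions in the classes list built during training, and that list is sorted as strings. If the class folders have numeric names, as the test path data/sorted/test/0/ suggests, index 86 is therefore not necessarily the folder named 86. A small sketch of mapping top-k indices back to folder names, using made-up folder names and probabilities (the function name is also an assumption):

```python
def top_k(probs, class_names, k=5):
    """Return the k most probable (class_name, probability) pairs."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_names[i], probs[i]) for i in order[:k]]


# Hypothetical folder names; string sorting puts '10' before '2'
class_names = sorted(['0', '1', '2', '10', '11'])
print(class_names)

probs = [0.05, 0.15, 0.70, 0.04, 0.06]  # made-up softmax output
print(top_k(probs, class_names, k=2))
```

In the prediction script, the same lookup would use the sorted subdirectory names of the training directory as class_names and y[0] as probs.
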




V. Reference Project

  1. Arsey/keras-transfer-learning-for-oxford102: Keras pretrained models (VGG16, InceptionV3, Resnet50, Resnet152) + Transfer Learning for predicting classes in the Oxford 102 flower dataset
    https://github.com/Arsey/keras-transfer-learning-for-oxford102
