Resuming Keras training from a checkpoint

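Sometimes training takes so long that, if it is interrupted unexpectedly or you want to adjust hyperparameters, you need to continue from the previous run instead of starting over. Following the approach of tf-slim, this post implements resumable training in Keras for fine-tuning. If no saved weights exist at the checkpoint path, the parameters are initialized from the pretrained ImageNet weights. If saved weights exist (from an earlier run), they are loaded instead and training continues from there.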

The first part summarizes the ways to initialize model parameters (skip it if you already know them).

I. Ways to initialize model parameters

1. Random initialization

from keras.applications.vgg16 import VGG16
base_model = VGG16(weights=None, include_top=False, input_shape=(224, 224, 3))

With weights=None (the value None, not the string 'None'), the model parameters are initialized randomly.
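"Random" here means each layer's default initializer: in Keras, Dense and Conv2D default to kernel_initializer='glorot_uniform' and bias_initializer='zeros'. The NumPy sketch below reproduces the Glorot-uniform rule for one conv kernel, using VGG16's first conv layer shape as an example. The function name glorot_uniform_kernel is hypothetical, and NumPy's RNG stands in for the Keras backend's RNG.

```python
import numpy as np

def glorot_uniform_kernel(kh, kw, in_ch, out_ch, seed=0):
    """Sample a conv kernel following the glorot_uniform rule (sketch)."""
    receptive_field = kh * kw
    fan_in = receptive_field * in_ch    # inputs feeding one output unit
    fan_out = receptive_field * out_ch  # outputs fed by one input unit
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    rng = np.random.default_rng(seed)
    kernel = rng.uniform(-limit, limit, size=(kh, kw, in_ch, out_ch))
    return kernel, limit

# VGG16 block1_conv1: 3x3 kernel, 3 input channels -> 64 filters
kernel, limit = glorot_uniform_kernel(3, 3, 3, 64)
bias = np.zeros(64)  # biases start at zero
print(kernel.shape, limit)
```

Every run with weights=None therefore starts from a different point unless the seed is fixed, which is why resuming requires loading saved weights rather than rebuilding the model.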

2. Initializing from saved model parameters

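The saved parameters can be pretrained ImageNet weights released for Keras (for VGG16 the download URLs are WEIGHTS_PATH and WEIGHTS_PATH_NO_TOP in the full code below), or parameters you saved yourself.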

1) Initializing from weights you saved yourself

Restore the parameters with model.load_weights(weights_path):

import os
import sys

checkpoint_dir = '/data/sfang/logo_classify/keras_model/checkpoint/best.hdf5'  # path of the saved weights file
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n' % checkpoint_dir)
    model.load_weights(checkpoint_dir)  # model must already be built with the same architecture
else:
    sys.stdout.write('INFO:No checkpoint found\n')
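Loading a checkpoint works because the saved arrays overwrite the freshly initialized ones, so the rebuilt model starts exactly where the previous run stopped. The sketch below mimics that with NumPy only. A "model" is a hypothetical list of weight arrays, and save_weights/load_weights here are stand-ins built on np.savez/np.load, not the Keras methods (which write HDF5).

```python
import os
import tempfile
import numpy as np

def save_weights(weights, path):
    # store the arrays in order under keys w0, w1, ...
    np.savez(path, **{f"w{i}": w for i, w in enumerate(weights)})

def load_weights(weights, path):
    # overwrite each array in place, in the order it was saved
    with np.load(path) as data:
        for i, w in enumerate(weights):
            saved = data[f"w{i}"]
            if saved.shape != w.shape:
                raise ValueError(f"w{i}: saved {saved.shape}, model {w.shape}")
            w[...] = saved

rng = np.random.default_rng(0)
trained = [rng.normal(size=(4, 3)), rng.normal(size=3)]  # weights at the end of run 1
path = os.path.join(tempfile.mkdtemp(), "best.npz")
save_weights(trained, path)

fresh = [np.zeros((4, 3)), np.zeros(3)]  # model rebuilt at the start of run 2
if os.path.exists(path):
    load_weights(fresh, path)
```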

2) Initializing from pretrained ImageNet weights

base_model = VGG16(weights='imagenet',include_top=False,input_shape=(224,224,3))

Set weights='imagenet'. The VGG16 source below shows that this, too, restores the parameters with model.load_weights(weights_path):

# load weights
if weights == 'imagenet':
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path)
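get_file downloads the weight file once into the Keras cache (~/.keras/models with cache_subdir='models') and checks it against file_hash. A 32-character hash like the ones above is treated as MD5, a 64-character one as SHA-256, and a cached copy that fails the check is downloaded again. The sketch below reproduces only this cache-and-verify step with hashlib. fetch_cached and fake_download are hypothetical, and no network access is involved.

```python
import hashlib
import os
import tempfile

def file_digest(path, algorithm):
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65535), b""):
            h.update(chunk)
    return h.hexdigest()

def fetch_cached(fname, file_hash, cache_dir, download):
    """Return the cached file, calling download(path) only if the cached copy
    is missing or fails the hash check (sketch of get_file's caching)."""
    algorithm = "sha256" if len(file_hash) == 64 else "md5"
    path = os.path.join(cache_dir, "models", fname)
    if os.path.exists(path) and file_digest(path, algorithm) == file_hash:
        return path  # valid cached copy, no download
    os.makedirs(os.path.dirname(path), exist_ok=True)
    download(path)
    if file_digest(path, algorithm) != file_hash:
        raise ValueError(f"{fname} failed the {algorithm} check")
    return path

payload = b"pretend these are VGG16 weights"
calls = []

def fake_download(path):
    calls.append(path)
    with open(path, "wb") as f:
        f.write(payload)

expected = hashlib.md5(payload).hexdigest()
cache = tempfile.mkdtemp()
first = fetch_cached("vgg16_notop.h5", expected, cache, fake_download)
second = fetch_cached("vgg16_notop.h5", expected, cache, fake_download)  # served from cache
```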

Notes:

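1) Sometimes you only want to restore some of the layers. For example, when fine-tuning you add layers after VGG16's base_model and want only base_model initialized from the pretrained weights. In that case model.load_weights(weights_path) fails with an error saying the number of layers in the model does not match the number in the weights file. Loading by layer name fixes this: model.load_weights(weights_path, by_name=True) copies weights only into layers whose names appear in the file and leaves the other layers (the new head) unchanged.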

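2) Instead of model.load_weights(weights_path), you can also copy the parameters layer by layer. This requires both models to list their layers in the same order: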

# copy weights by position; "- 1" skips the last layer (e.g. a classifier whose shape differs)
for i in range(len(model_pretrained.layers) - 1):
    model_new.layers[i].set_weights(model_pretrained.layers[i].get_weights())
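The two notes differ in how layers are matched. by_name=True matches layers by name, so a new head named fc8 is skipped because the notop file has no such layer. The set_weights loop matches layers by position, so every copied pair must have the same shapes. The NumPy sketch below imitates both rules, with dictionaries and lists of arrays standing in for the HDF5 file and the models. load_by_name and copy_by_position are hypothetical helpers, not Keras functions.

```python
import numpy as np

def load_by_name(model_layers, saved):
    """Copy weights only into layers whose name also appears in `saved`."""
    loaded = []
    for name, current in list(model_layers.items()):
        if name in saved:
            for s, c in zip(saved[name], current):
                if s.shape != c.shape:
                    raise ValueError(f"{name}: {s.shape} vs {c.shape}")
            model_layers[name] = [s.copy() for s in saved[name]]
            loaded.append(name)
    return loaded

def copy_by_position(src_layers, dst_layers):
    """Layer i receives layer i's weights; the last layer is skipped."""
    for i in range(len(src_layers) - 1):
        for s, d in zip(src_layers[i], dst_layers[i]):
            if s.shape != d.shape:  # Keras' set_weights also rejects shape mismatches
                raise ValueError(f"layer {i}: {s.shape} vs {d.shape}")
        dst_layers[i] = [s.copy() for s in src_layers[i]]

# notop weight file: convolutional base only (one layer shown for brevity)
saved = {"block1_conv1": [np.ones((3, 3, 3, 64)), np.ones(64)]}
# fine-tuning model: the same base layer plus a new head named fc8
model = {"block1_conv1": [np.zeros((3, 3, 3, 64)), np.zeros(64)],
         "fc8": [np.zeros((4096, 5)), np.zeros(5)]}
loaded = load_by_name(model, saved)  # fc8 keeps its initial values

pretrained = [[np.ones((3, 3, 3, 64)), np.ones(64)],
              [np.ones((4096, 1000)), np.ones(1000)]]  # base layer + 1000-way classifier
new = [[np.zeros((3, 3, 3, 64)), np.zeros(64)],
       [np.zeros((4096, 5)), np.zeros(5)]]             # base layer + 5-way classifier
copy_by_position(pretrained, new)
```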

II. Resumable training

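Take VGG16 fine-tuning as an example. When training starts, if previously trained parameters have been saved, the model is initialized from that file and the previous run continues. Otherwise it is initialized from the pretrained ImageNet weights.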

Resume-training code:

weights = 'imagenet'
include_top = False

# load weights
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n'%checkpoint_dir)
    model.load_weights(checkpoint_dir)

elif weights == 'imagenet':
    sys.stdout.write('INFO:Load weights from imagenet\n')
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path,by_name=True)
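The branching above can be isolated into a small function and checked without Keras. choose_init_source is a hypothetical helper. Its weights=None branch (random initialization) is an addition for completeness, and the file names are the ones used above. The function also makes explicit why the two branches load differently. A checkpoint was written by this same model, so plain load_weights works. The ImageNet notop file lacks the new head, so by_name=True is needed.

```python
import os
import tempfile

def choose_init_source(checkpoint_path, weights="imagenet", include_top=False):
    """Return (source, path_or_filename, by_name) for the starting weights."""
    if os.path.exists(checkpoint_path):
        # resume: same architecture as the saved model, so load by layer order
        return "checkpoint", checkpoint_path, False
    if weights == "imagenet":
        fname = ("vgg16_weights_tf_dim_ordering_tf_kernels.h5" if include_top
                 else "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5")
        # first fine-tuning run: the head is new, so match layers by name
        return "imagenet", fname, True
    return "random", None, False

ckpt = os.path.join(tempfile.mkdtemp(), "best.hdf5")
print(choose_init_source(ckpt))  # no checkpoint yet, so the ImageNet file is chosen
open(ckpt, "wb").close()         # simulate the checkpoint written during run 1
print(choose_init_source(ckpt))  # checkpoint exists, so training resumes from it
```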

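Data layout:

Directory structure:

[Figure 1: dataset directory structure]

Folders 0, 1, 2, 3 and 4 each hold the images of one class; this example has five classes in total.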
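With this layout, flow_from_directory treats each subdirectory as one class and numbers the classes in sorted order of the folder names. Folders 0 to 4 therefore become class indices 0 to 4, and class_mode='categorical' yields five-column one-hot labels. Sorting is by string, so a folder named 10 would sort before 2. The number of batches needed to cover a set once is ceil(num_images / batch_size); the image count is available as generator.samples. The full code below hard-codes steps_per_epoch=1028 and validation_steps=5, so each validation pass with batch_size=16 scores only 80 images. In the sketch, infer_class_indices and steps_for are hypothetical helpers, and the image count is a placeholder rather than the article's dataset.

```python
import math
import os
import tempfile

def infer_class_indices(directory):
    """Map each subdirectory name to an integer label, in sorted (string) order."""
    classes = sorted(d for d in os.listdir(directory)
                     if os.path.isdir(os.path.join(directory, d)))
    return {name: i for i, name in enumerate(classes)}

def steps_for(num_images, batch_size):
    """Batches needed to see every image once; the last batch may be partial."""
    return math.ceil(num_images / batch_size)

root = tempfile.mkdtemp()
for name in ["3", "0", "4", "1", "2"]:
    os.makedirs(os.path.join(root, name))           # one folder per class
open(os.path.join(root, "notes.txt"), "w").close()  # plain files are not classes

print(infer_class_indices(root))
print(steps_for(65_000, 64))  # hypothetical training-set size, batch_size=64
```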

Full project code:

#coding=utf-8
import keras
import os
import glob
import sys
import argparse
import tensorflow as tf

from matplotlib import pyplot as plt
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Model
from keras.utils import get_file
from keras.layers import Dense,GlobalAveragePooling2D,Flatten,Dropout,Conv2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD,RMSprop
from keras.callbacks import TensorBoard
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report
from keras.backend.tensorflow_backend import set_session

# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.4
# set_session(tf.Session(config=config))

IM_WIDTH, IM_HEIGHT = 299, 299  # InceptionV3 input size (unused here; the VGG16 code uses 224x224)
FC_SIZE = 1024                  # number of fully connected units (unused)
NB_IV3_LAYERS_TO_FREEZE = 172   # number of layers to freeze (unused)
# training, test and validation set paths
train_dir = '/data/sfang/logo_classify/data/image_preprocessed/images/train/'
test_dir = '/data/sfang/logo_classify/data/image_preprocessed/images/test/'
val_dir = '/data/sfang/logo_classify/data/image_classify/val/'
checkpoint_dir = '/data/sfang/logo_classify/keras_model/checkpoint/best.hdf5'

WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'

#Define the training data generator with the ImageDataGenerator class
train_datagen = ImageDataGenerator(
    #preprocessing_function=preprocess_input, # image preprocessing function (matches how the ImageNet VGG16 weights were trained, unlike rescale)
    #rotation_range=30,                       # rotation angle range
    #width_shift_range=0.2,                   # horizontal shift range
    #height_shift_range=.2,                   # vertical shift range
    #shear_range=.2,
    #zoom_range=.2,
    #horizontal_flip=True,
    rescale=1./255
)
test_datagen = ImageDataGenerator(
    #preprocessing_function=preprocess_input, # image preprocessing function
    # rotation_range=30,                       # rotation angle range
    # width_shift_range=0.2,                   # horizontal shift range
    # height_shift_range=.2,                   # vertical shift range
    # shear_range=.2,
    # zoom_range=.2,
    # horizontal_flip=True,
    rescale=1./255
)
validation_datagen = ImageDataGenerator(
    #preprocessing_function=preprocess_input, # image preprocessing function
    # rotation_range=30,                       # rotation angle range
    # width_shift_range=0.2,                   # horizontal shift range
    # height_shift_range=.2,                   # vertical shift range
    # shear_range=.2,
    # zoom_range=.2,
    # horizontal_flip=True,
    rescale=1./255
)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224,224),
    batch_size=64,
    class_mode='categorical',
)
validation_generator = validation_datagen.flow_from_directory(
    val_dir,
    target_size=(224,224),
    batch_size=16,
    class_mode='categorical'
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=16,
    class_mode='categorical'
)


def classify_model(base_model,nums_classes):
    x = base_model.output
    x = Conv2D(4096,(7,7),activation='relu',padding='valid',name='fc5')(x)
    x = Flatten()(x)
    x = Dense(4096, activation='relu', name='fc6')(x)
    x = Dropout(rate=.5, name='dropout6')(x)
    x = Dense(4096, activation='relu', name='fc7')(x)
    x = Dropout(rate=.5, name='dropout7')(x)
    predictions = Dense(nums_classes,activation='softmax',name='fc8')(x)
    model = Model(inputs=base_model.input, outputs=predictions, name='my_vgg16')
    return model

def define_trainable_layers(model,base_model):
    # for layer in base_model.layers:
    #     layer.trainable = False
    opt = RMSprop(lr=1e-4)
    model.compile(optimizer=opt,loss='categorical_crossentropy',metrics=['accuracy'])

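#Define the model
#bottleneck=vgg16
# weights=None: the weights are loaded once below (checkpoint or ImageNet), so they are not loaded here to avoid initializing twice
base_model = VGG16(weights=None,include_top=False,input_shape=(224,224,3))
#Add the classification head (a 7x7 Conv2D followed by fully connected layers) on top of the base model
model = classify_model(base_model,5)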

weights = 'imagenet'
include_top = False

# load weights
# If saved weights exist, continue the previous run; otherwise start from the pretrained ImageNet weights
if os.path.exists(checkpoint_dir):
    sys.stdout.write('INFO:checkpoint exists, Load weights from %s\n'%checkpoint_dir)
    model.load_weights(checkpoint_dir)

elif weights == 'imagenet':
    sys.stdout.write('INFO:Load weights from imagenet\n')
    if include_top:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                WEIGHTS_PATH,
                                cache_subdir='models',
                                file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models',
                                file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path,by_name=True)


#Define which layers are trainable and compile the model
model.summary()
define_trainable_layers(model,base_model)

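#Monitor a metric and save the model whenever it improves in an epoch (loss decreases or accuracy increases).
#val_acc should be maximized, so mode='max'; mode='min' would keep the least accurate epoch.
#In Keras 2.3 and later this metric is logged as 'val_accuracy' instead of 'val_acc'.
checkpoint = ModelCheckpoint(checkpoint_dir,monitor='val_acc',
                             mode='max',save_best_only=True,verbose=1)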
callbacks = [checkpoint,TensorBoard(log_dir='./log')]


history = model.fit_generator(
    train_generator,
    epochs=150,
    shuffle=True,
    callbacks=callbacks,
    steps_per_epoch=1028,
    validation_data=test_generator,  # the test set is used for validation; validation_generator (val_dir) is unused
    validation_steps=5
)
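Two details decide what a resumed run actually continues from. First, ModelCheckpoint(save_best_only=True) overwrites the file only when the monitored value improves. mode='min' means smaller is better and mode='max' means larger is better, so for val_acc the right mode is 'max' (or 'auto'). Second, load_weights restores the layer weights only. The epoch counter restarts at 0, and optimizer state such as RMSprop's moving averages is not restored. One common way to keep the epoch count is to write checkpoints as 'weights.{epoch:02d}.hdf5' (ModelCheckpoint formats {epoch} in its filepath) and pass the latest number to fit_generator as initial_epoch. The sketch below checks both ideas in plain Python. epochs_saved, latest_checkpoint, the file-name pattern and the accuracy history are assumptions for illustration, not part of the original project.

```python
import math
import operator
import os
import re
import tempfile

def epochs_saved(history, mode):
    """Epochs at which a save_best_only checkpoint would be written."""
    improved = operator.lt if mode == "min" else operator.gt
    best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(history, start=1):
        if improved(value, best):
            best = value
            saved.append(epoch)
    return saved

CKPT = re.compile(r"weights\.(\d+)\.hdf5$")  # files named 'weights.{epoch:02d}.hdf5'

def latest_checkpoint(directory):
    """(path, epoch) of the highest-numbered checkpoint, or (None, 0) if none exist."""
    path, last = None, 0
    for name in os.listdir(directory):
        m = CKPT.match(name)
        if m and int(m.group(1)) > last:
            path, last = os.path.join(directory, name), int(m.group(1))
    return path, last

val_acc = [0.62, 0.71, 0.68, 0.80, 0.77]  # made-up accuracy history
print(epochs_saved(val_acc, "max"))       # each improvement is saved
print(epochs_saved(val_acc, "min"))       # only the lowest accuracy is kept

ckpt_dir = tempfile.mkdtemp()
for epoch in (1, 2, 10):
    open(os.path.join(ckpt_dir, "weights.%02d.hdf5" % epoch), "wb").close()
path, initial_epoch = latest_checkpoint(ckpt_dir)
# In the training script (Keras 2 API) one would then call:
#   model.load_weights(path)
#   model.fit_generator(train_generator, epochs=150, initial_epoch=initial_epoch, ...)
```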
