机器学习学习笔记--迁移学习

迁移学习,故名思意,就是将已经训练完成模型的一部分参数迁移到新的模型中来。我们都知道,训练一个复杂的神经网络是非常费时费力的过程。往往我们需要收集大量已经标注好的训练集,使用昂贵的 GPU 提供算力,再花费不少的时间,才能训练出一个表现不错的模型。

如果两个任务如果存在一定的相关性,比如已经训练好的一个非常优秀的猫狗识别模型,现在需要完成猫兔识别。我们就无需从头开始训练新模型,可以将原模型中的一些权值和中间层迁移到新模型中,这样就大大降低了我们需要的算力,节省了训练时间。

除此之外,对于刚刚接触深度学习的朋友来讲,由于自身经验缺失,大多是没有能力训练一个较好的模型。如果能将大牛训练后的模型加以迁移,那么很快就可以可以让模型满足自己的需求。

迁移学习是机器学习中的非常重要,且值得深入研究的领域。反观人类本身,我们其实就拥有强大的迁移学习能力。经常听到的「举一反三」,也就是对迁移学习的最好印证。

# -*- coding: utf-8 -*-

import os

import glob

from keras.applications.inception_v3 import InceptionV3

from keras.models import Model

from keras.layers import Dense, GlobalAveragePooling2D

from keras.preprocessing.image import ImageDataGenerator

# 读取文件数量和分类数量

train_num = 0

for r, dirs, files in os.walk('train_dir'):

for dr in dirs:

train_num += len(glob.glob(os.path.join(r, dr + "/*")))

val_num = 0

for r, dirs, files in os.walk('val_dir'):

for dr in dirs:

val_num += len(glob.glob(os.path.join(r, dr + "/*")))

# 图像数据预处理

datagen = ImageDataGenerator(rotation_range=40,

width_shift_range=0.2,

height_shift_range=0.2,

shear_range=0.2,

zoom_range=0.2,

horizontal_flip=True,

fill_mode='nearest')

train_generator = datagen.flow_from_directory('train_dir',

target_size=(299, 299),

batch_size=32)

validation_generator = datagen.flow_from_directory('val_dir',

target_size=(299, 299),

batch_size=32)

# 冻结原模型全连接层

base_model = InceptionV3(weights='imagenet', include_top=False)

# 重新配置全连接层

x = base_model.output

x = GlobalAveragePooling2D()(x)

x = Dense(1024, activation='relu')(x)

predictions = Dense(2, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# 重新训练迁移模型

for layer in base_model.layers:

layer.trainable = False

model.compile(optimizer='rmsprop',

loss='categorical_crossentropy',

metrics=['accuracy'])

model.fit_generator(train_generator,

epochs=2,

steps_per_epoch=train_num,

validation_data=validation_generator,

validation_steps=val_num,

class_weight='auto')

# 将模型权重保存

model.save_weights('inceptionv3-tl.h5')

# 将模型保存

model.save('inceptionv3-tl.model')

你可能感兴趣的:(机器学习学习笔记--迁移学习)