AI in Practice: Transfer Learning with ResNet for Classification

Transfer learning comes in two forms:

1. Feature Extraction

2. Fine-Tuning
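The practical difference between the two is which base-model layers are trainable. Feature Extraction freezes the whole pretrained base (in Keras, `base_model.trainable = False`) and trains only the new classification head. Fine-Tuning, as the TensorFlow transfer-learning tutorial does it, unfreezes the base and then re-freezes every layer below a cut-off index. The sketch below mimics that pattern with a hypothetical `Layer` stand-in class instead of real Keras layers so it runs without TensorFlow; `freeze_for_feature_extraction`, `freeze_for_fine_tuning` and `fine_tune_at` are illustrative names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    """Hypothetical stand-in for a Keras layer: only a name and a trainable flag."""
    name: str
    trainable: bool = True


def freeze_for_feature_extraction(layers: List[Layer]) -> None:
    # Feature Extraction: every pretrained layer is frozen; only the new head trains.
    for layer in layers:
        layer.trainable = False


def freeze_for_fine_tuning(layers: List[Layer], fine_tune_at: int) -> None:
    # Fine-Tuning: unfreeze everything, then re-freeze the layers below fine_tune_at,
    # so only the top (most task-specific) layers are updated.
    for layer in layers:
        layer.trainable = True
    for layer in layers[:fine_tune_at]:
        layer.trainable = False


def trainable_names(layers: List[Layer]) -> List[str]:
    return [layer.name for layer in layers if layer.trainable]


if __name__ == "__main__":
    base = [Layer(f"block{i}") for i in range(1, 6)]
    freeze_for_feature_extraction(base)
    print(trainable_names(base))  # []
    freeze_for_fine_tuning(base, fine_tune_at=3)
    print(trainable_names(base))  # ['block4', 'block5']
```

Fine-tuning is usually run after the head has been trained with the base frozen, and with a lower learning rate, so the pretrained weights are only nudged.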

This article uses TensorFlow 2.0 and the cats_vs_dogs dataset: it builds the base model with tf.keras.applications and uses ResNet101 for Feature Extraction.
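With `include_top=False`, ResNet101 drops its ImageNet classification layers and returns a feature map. Its five stride-2 downsampling steps shrink a 160x160 input by a factor of 32, so each image becomes a 5x5x2048 block; `GlobalAveragePooling2D` then averages the 5x5 positions into one 2048-value vector, which `Dense(1)` maps to a single logit. The pure-Python sketch below reproduces that shape arithmetic and the pooling step on nested lists; `resnet_output_size` and `global_average_pool` are illustrative names, not Keras APIs, and the factor-32 rule assumes an input size divisible by 32.

```python
from typing import List

Feature = List[List[List[float]]]  # height x width x channels


def resnet_output_size(img_size: int, total_stride: int = 32) -> int:
    # Five stride-2 steps halve the spatial size five times: 2**5 == 32.
    # Assumes img_size is divisible by 32 (padding changes the result otherwise).
    return img_size // total_stride


def global_average_pool(fmap: Feature) -> List[float]:
    # Average each channel over all height x width positions,
    # which is what GlobalAveragePooling2D does per image.
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [
        sum(fmap[i][j][k] for i in range(h) for j in range(w)) / (h * w)
        for k in range(c)
    ]


if __name__ == "__main__":
    print(resnet_output_size(160))  # 5
    fmap = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # 2x2 positions, 2 channels
    print(global_average_pool(fmap))  # [4.0, 5.0]
```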

Core code:

'''
Transfer learning with a pretrained ConvNet: resnet101
Reference:
https://tensorflow.google.cn
Downloaded weights location:
~/.keras/models/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
'''

from __future__ import absolute_import, division, print_function, unicode_literals

import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

keras = tf.keras

# Dataset
# Data preprocessing
# Data download
# Use TensorFlow Datasets to load the cats and dogs dataset.
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

SPLIT_WEIGHTS = (8, 1, 1)  # 80% train, 10% validation, 10% test
# tfds.Split.TRAIN.subsplit() is deprecated and has been removed from newer
# tensorflow_datasets releases; percent slicing gives the same 8:1:1 split.
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']
(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=splits,
    with_info=True, as_supervised=True)

print(raw_train)
print(raw_validation)
print(raw_test)

# Show two training images with their labels
get_label_name = metadata.features['label'].int2str
for image, label in raw_train.take(2):
    plt.figure()
    plt.imshow(image)
    plt.title(get_label_name(label))

# Format the data
IMG_SIZE = 160  # All images will be resized to 160x160

def format_example(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    # ResNet101's ImageNet weights expect caffe-style input (RGB -> BGR, ImageNet
    # per-channel mean subtracted), not the [-1, 1] scaling (image / 127.5 - 1)
    # used with MobileNetV2.
    image = tf.keras.applications.resnet.preprocess_input(image)
    return image, label

# Shuffle and batch the data
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

BATCH_SIZE = 2  # originally 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

# Take one batch to inspect its shape
for image_batch, label_batch in train_batches.take(1):
    pass
print(image_batch.shape)

# Create the base model
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model ResNet101
base_model = tf.keras.applications.ResNet101(input_shape=IMG_SHAPE,
                                             include_top=False,
                                             weights='imagenet')
feature_batch = base_model(image_batch)
print(feature_batch.shape)

# Freeze the convolutional base (Feature Extraction)
base_model.trainable = False
base_model.summary()

# Add a classification head
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

model = tf.keras.Sequential([
    base_model,
    global_average_layer,
    prediction_layer
])

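# Compile the model
base_learning_rate = 0.0001
# Dense(1) has no activation and outputs a logit, so the loss must use
# from_logits=True and accuracy must threshold the logit at 0.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])
model.summary()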

# Train the model
num_train, num_val, num_test = (
    metadata.splits['train'].num_examples * weight // sum(SPLIT_WEIGHTS)
    for weight in SPLIT_WEIGHTS)

initial_epochs = 1  # originally 10
history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)

# Save the weights to an HDF5 file
model.save_weights('transfer_learning-resnet101-model-cats-dogs.h5', save_format='h5')

# Restore the model's state from the same file
# model.load_weights('transfer_learning-resnet101-model-cats-dogs.h5')

# Learning curves
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
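`tfds.Split.TRAIN.subsplit(weighted=...)`, the TF 2.0-era way to carve train/validation/test out of the single `train` split of cats_vs_dogs, has been removed from newer tensorflow_datasets releases; the replacement is percent slicing such as `'train[:80%]'`. The hypothetical helper below turns relative weights like `SPLIT_WEIGHTS = (8, 1, 1)` into such slicing strings. It assumes the weights divide 100 into whole percentages; other weights are rounded down at each boundary.

```python
from typing import List, Sequence


def weighted_split_strings(weights: Sequence[int], split: str = "train") -> List[str]:
    """Turn relative weights into TFDS percent-slicing strings (hypothetical helper)."""
    total = sum(weights)
    # Cumulative boundaries, e.g. (8, 1, 1) -> [0, 8, 9, 10]
    bounds = [0]
    for w in weights:
        bounds.append(bounds[-1] + w)
    percents = [100 * b // total for b in bounds]
    out = []
    for i in range(len(weights)):
        # Leave the first start and the last end open so the whole split is covered.
        start = "" if i == 0 else f"{percents[i]}%"
        end = "" if i == len(weights) - 1 else f"{percents[i + 1]}%"
        out.append(f"{split}[{start}:{end}]")
    return out


if __name__ == "__main__":
    print(weighted_split_strings((8, 1, 1)))
    # ['train[:80%]', 'train[80%:90%]', 'train[90%:]']
```

The resulting list can be passed as `split=` to `tfds.load`, which then returns one dataset per string.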

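Output of the run:

1. Sample images from the cats_vs_dogs dataset

2. Learning curves

3. Results after training for 1 epoch

4. The model weights are saved to ./transfer_learning-resnet101-model-cats-dogs.h5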
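The classification head is `Dense(1)` with no activation, so the model outputs a raw logit rather than a probability. The loss therefore has to be computed from logits (`from_logits=True` in `tf.keras.losses.BinaryCrossentropy`), and once the weights are restored, a prediction becomes a label by checking whether the logit is above 0, which is the same as sigmoid(logit) > 0.5. In cats_vs_dogs, label 0 is `cat` and label 1 is `dog`. The pure-Python sketch below shows both steps; the function names are illustrative, and `bce_from_logit` uses the standard numerically stable form of sigmoid cross-entropy.

```python
import math
from typing import List

LABEL_NAMES = ["cat", "dog"]  # cats_vs_dogs label order: 0 = cat, 1 = dog


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_logit(y: int, z: float) -> float:
    # Stable binary cross-entropy on a raw logit z with label y in {0, 1}:
    # max(z, 0) - z*y + log(1 + exp(-|z|)) == -[y*log(s) + (1-y)*log(1-s)], s = sigmoid(z)
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def logits_to_labels(logits: List[float]) -> List[str]:
    # logit > 0  <=>  sigmoid(logit) > 0.5  <=>  class 1 ("dog")
    return [LABEL_NAMES[int(z > 0)] for z in logits]


if __name__ == "__main__":
    print(logits_to_labels([-2.0, 0.5]))  # ['cat', 'dog']
    print(round(bce_from_logit(1, 0.0), 4))  # 0.6931
```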
