This experiment uses ResNet, one of the models most widely used for image classification. Trained on ImageNet, it achieved excellent results, and the residual idea has since been adopted in network models for all kinds of tasks. ResNet has, of course, evolved into different structures; for a translation and analysis of the ResNet-V1 paper, see the link: ResNet-V1论文翻译及解析. This experiment uses ResNet-V2-50, so the detailed structure of ResNet-V2 is introduced next.
ResNet-V2 paper link: https://arxiv.org/pdf/1603.05027.pdf
The core idea of the paper: when identity mappings are used, the network performs better if the forward signal and the backward gradients flow directly from one block to the next, without passing through operations such as ReLU.
The similarities and differences between ResNet-V1 and ResNet-V2 are shown in the figure below, where (a) is ResNet-V1 and (b) is ResNet-V2:
Differences: (1) the basic unit changes: ResNet-V1 stacks [Conv+BN+ReLU]+[Conv+BN], whereas ResNet-V2 stacks [BN+ReLU+Conv]+[BN+ReLU+Conv]; (2) whether a ReLU activation follows the addition: V1 applies one, V2 does not.
Similarity: the residual idea itself is unchanged. A minimal sketch contrasting the two orderings is given below.
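To make the difference concrete, here is a minimal sketch of the two residual-unit orderings written with tf.contrib.slim. It is an illustration under stated assumptions (3×3 convolutions, a same-shaped identity shortcut), not the paper's reference implementation:

import tensorflow as tf
slim = tf.contrib.slim

def residual_unit_v1(x, channels, is_training):
    # ResNet-V1 ordering: [Conv+BN+ReLU] + [Conv+BN], add, then ReLU
    shortcut = x
    y = slim.conv2d(x, channels, [3, 3], activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm,
                    normalizer_params={'is_training': is_training})
    y = slim.conv2d(y, channels, [3, 3], activation_fn=None,
                    normalizer_fn=slim.batch_norm,
                    normalizer_params={'is_training': is_training})
    return tf.nn.relu(shortcut + y)  # ReLU after the addition

def residual_unit_v2(x, channels, is_training):
    # ResNet-V2 ordering: [BN+ReLU+Conv] + [BN+ReLU+Conv]; the addition is left
    # untouched, so the signal travels between blocks along a pure identity path
    shortcut = x
    y = tf.nn.relu(slim.batch_norm(x, is_training=is_training))
    y = slim.conv2d(y, channels, [3, 3], activation_fn=None)
    y = tf.nn.relu(slim.batch_norm(y, is_training=is_training))
    y = slim.conv2d(y, channels, [3, 3], activation_fn=None)
    return shortcut + y  # no ReLU after the addition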
This experiment uses the ResNet-V2-50 structure; the detailed structure of each module is shown below:
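For reference, based on the resnet_v2_50 definition in slim, the stages are:
- conv1: 7×7, 64 filters, stride 2, followed by 3×3 max pooling with stride 2
- block1: 3 bottleneck units (1×1 64 → 3×3 64 → 1×1 256)
- block2: 4 bottleneck units (1×1 128 → 3×3 128 → 1×1 512)
- block3: 6 bottleneck units (1×1 256 → 3×3 256 → 1×1 1024)
- block4: 3 bottleneck units (1×1 512 → 3×3 512 → 1×1 2048)
- global average pooling, then a 1×1 convolution ('logits') and softmax ('predictions')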
The code is organized into the following parts:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.slim.nets import resnet_v2
from tensorflow.contrib.slim.nets import resnet_v1
# from tensorflow.contrib.slim.nets import vgg
import VGG as vgg
from config import cfg
class Model:
    def __init__(self):
        self.base_architecture = cfg.ResNet.Base_Architecture[1]
        self.pre_trained_model = cfg.ResNet.Pre_Trained_Model[1]
        self.batch_norm_decay = cfg.ResNet.Batch_Norm_Decay
        # Model inputs
        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, cfg.Train.Input_Size[0], cfg.Train.Input_Size[1], 3),
                                     name='input')
        self.label_c = tf.placeholder(tf.int64, shape=(cfg.Train.Batch_Size,), name='label')
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')
        self.loss_function = cfg.Train.Loss_Function[0]
        # Loss function
        with tf.name_scope('loss_function'):
            self.logits, self.end_point = self.model(self.trainable,
                                                     self.pre_trained_model,
                                                     self.base_architecture,
                                                     num_classes=cfg.ResNet.Num_Classes)
            if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                logits_squeeze = tf.squeeze(self.logits, axis=[1, 2])
            else:
                logits_squeeze = self.logits
            one_hot = tf.one_hot(self.label_c, cfg.ResNet.Num_Classes)
            if self.loss_function == 'softmax':
                print('using softmax cross entropy loss function')
                loss_net = tf.losses.softmax_cross_entropy(one_hot, logits_squeeze)
            else:
                print('using sigmoid cross entropy loss function')
                loss_net = tf.losses.sigmoid_cross_entropy(one_hot, logits_squeeze, label_smoothing=0.1)
                # loss_net = tf.nn.weighted_cross_entropy_with_logits()
            # Add an L2 regularization term
            l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
            self.loss = loss_net + l2 * 0.0005
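            # The slim ResNet heads emit logits of shape [N, 1, 1, num_classes],
            # hence the squeeze over axes 1 and 2 above, while the VGG head is
            # already 2-D. The total loss is the cross entropy plus
            # 0.0005 * sum of the L2 norms of all trainable weights.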
        # All variables saved with the model
        self.net_variables = tf.global_variables()
        # Compute accuracy to monitor training
        with tf.name_scope('compute_accuracy'):
            if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                if self.loss_function == 'softmax':
                    end_point_soft = self.end_point['predictions']
                    self.end_point_squeeze = tf.squeeze(end_point_soft, axis=[1, 2])
                else:
                    end_point_soft = tf.nn.sigmoid(self.logits)
                    self.end_point_squeeze = tf.squeeze(end_point_soft, axis=[1, 2])
            else:
                self.end_point_squeeze = self.end_point['vgg_16/fc8']
            # Accuracy
            correct_prediction = tf.equal(tf.argmax(self.end_point_squeeze, 1), self.label_c)
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
            # Per-step confusion matrix, used to track sensitivity and specificity during training
            confusion_matrix = tf.confusion_matrix(self.label_c, tf.argmax(self.end_point_squeeze, 1), num_classes=2)
            TN = confusion_matrix[0][0]
            FP = confusion_matrix[0][1]
            FN = confusion_matrix[1][0]
            TP = confusion_matrix[1][1]
            # acc = (TP + TN) / (TP + TN + FP + FN)
            self.sensitive = TP / (TP + FN)
            self.specify = TN / (TN + FP)
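            # sensitivity = TP / (TP + FN)  (recall / true-positive rate)
            # specificity = TN / (TN + FP)  (true-negative rate)
            # These only make sense for a binary task, matching num_classes=2 above.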
        # Learning-rate schedule: linear warm-up followed by cosine decay
        with tf.name_scope('learning_rate'):
            self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
            warmup_steps = tf.constant(cfg.Train.Warmup_Epochs * cfg.Train.Step_Per_Epoch,
                                       dtype=tf.float64,
                                       name='warmup_steps')
            # Total number of training steps
            train_steps = tf.constant((cfg.Train.First_Stage_Epochs +
                                       cfg.Train.Second_Stage_Epochs) * cfg.Train.Step_Per_Epoch,
                                      dtype=tf.float64,
                                      name='train_steps')
            self.learn_rate = tf.cond(pred=self.global_step < warmup_steps,
                                      true_fn=lambda: self.global_step / warmup_steps * cfg.Train.Learn_Rate_Init,
                                      false_fn=lambda: cfg.Train.Learn_Rate_End + 0.5 *
                                      (cfg.Train.Learn_Rate_Init - cfg.Train.Learn_Rate_End) *
                                      (1 + tf.cos((self.global_step - warmup_steps) /
                                                  (train_steps - warmup_steps) * np.pi)))
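            # Schedule: for step < warmup_steps, lr = step / warmup_steps * lr_init
            # (linear warm-up); afterwards lr = lr_end + 0.5 * (lr_init - lr_end) *
            # (1 + cos(pi * (step - warmup) / (train_steps - warmup))), a cosine
            # decay from lr_init down to lr_end over the remaining steps.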
            # self.learn_rate = tf.train.exponential_decay(cfg.Train.Learn_Rate_Init,
            #                                              self.global_step,
            #                                              decay_steps=400,
            #                                              decay_rate=0.9)
            global_step_update = tf.assign_add(self.global_step, 1.0)
        # Maintain exponential moving averages of the trainable variables
        with tf.name_scope("define_weight_decay"):
            moving_ave = tf.train.ExponentialMovingAverage(cfg.ResNet.Moving_Ave_Decay).apply(tf.trainable_variables())
        # For transfer learning, train in two stages: first optimize only the
        # final layer while all other layers' parameters stay fixed
        with tf.name_scope("first_train_stage"):
            # Collect the variables to be optimized in the first stage
            self.first_stage_trainable_var_list = []
            for var in tf.trainable_variables():
                var_name = var.op.name
                # print('var_name: ', var_name)
                var_name_mess = str(var_name).split('/')
                print(var_name_mess)
                if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                    if var_name_mess[1] in ['logits']:
                        self.first_stage_trainable_var_list.append(var)
                else:
                    if var_name_mess[1] in ['fc6', 'fc7', 'fc8']:
                        self.first_stage_trainable_var_list.append(var)
            first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=self.first_stage_trainable_var_list)
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([first_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_frozen_variables = tf.no_op()
                        # self.train_op_with_frozen_variables = tf.group(moving_ave)
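            # Because first_stage_optimizer, global_step_update, and moving_ave were
            # created outside these contexts, the nesting does not strictly order
            # them; it makes the tf.no_op() depend on all of them plus the batch-norm
            # UPDATE_OPS, so running the no_op triggers every update in one call.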
        # Second stage: optimize all parameters
        with tf.name_scope("second_train_stage"):
            second_stage_trainable_var_list = tf.trainable_variables()
            second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(
                self.loss, var_list=second_stage_trainable_var_list)
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([second_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_all_variables = tf.no_op()
                        # self.train_op_with_all_variables = tf.group(moving_ave)
        # Restoring and saving the model
        with tf.name_scope('loader_and_saver'):
            # Restore all trained parameters except the final layer
            variables_to_restore = []
            for v in self.net_variables:
                print("OP Name: ", v.op.name)
                if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                    if v.name.split('/')[1] not in ['logits']:
                        print('v: ', v)
                        variables_to_restore.append(v)
                else:
                    print('vgg model')
                    if v.name.split('/')[1] not in ['fc6', 'fc7', 'fc8']:
                        # print('v: ', v)
                        variables_to_restore.append(v)
            print('==================================')
            # self.loader = tf.train.Saver(self.net_var)
            self.loader = tf.train.Saver(variables_to_restore)
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=300)
        # Collect training metrics so TensorBoard can visualize the training process
        with tf.name_scope('collect_summary'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.scalar('learning_rate', self.learn_rate)
            tf.summary.scalar('sensitive', self.sensitive)
            tf.summary.scalar('specify', self.specify)
            self.merged = tf.summary.merge_all()

    # Build the network from the official slim model definitions
    def model(self, is_training, pre_trained_model, base_architecture, num_classes):
        """
        Load the network structure.
        :param is_training: flag indicating whether the model is in training mode
        :param pre_trained_model: path to the pre-trained weights
        :param base_architecture: name of the backbone network
        :param num_classes: number of classes
        :return: the network's final layer and a dict of all layer outputs
        """
        if base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
            print('using ', base_architecture)
            batch_norm_decay = 0.997
            base_model = resnet_v2.resnet_v2_50
            with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=batch_norm_decay)):
                logits, end_points = base_model(self.inputs, num_classes=num_classes, is_training=is_training)
        else:
            print('using vgg16')
            weight_decay = 0.0005
            base_model = vgg.vgg_16
            with tf.contrib.slim.arg_scope(vgg.vgg_arg_scope(weight_decay=weight_decay)):
                logits, end_points = base_model(self.inputs,
                                                num_classes=num_classes,
                                                is_training=is_training,
                                                spatial_squeeze=is_training)
        # if is_training:
        #     exclude = [base_architecture + '/logits', 'global_step']
        #     print('exclude: ', exclude)
        #
        #     variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
        #     print("variables to restore: ", variables_to_restore)
        #
        #     tf.train.init_from_checkpoint(pre_trained_model,
        #                                   {v.name.split(':')[0]: v for v in variables_to_restore})
        #     print('restore from pretrained model\n')
        return logits, end_points
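To make the two-stage training flow concrete, here is a minimal sketch of how this Model class might be driven, continuing the module above. The session loop and the get_batch() data feeder are illustrative assumptions, not part of the original code; the cfg fields are the ones already referenced in the class:

model = Model()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore every layer except the classification head from the pre-trained weights
    model.loader.restore(sess, model.pre_trained_model)
    for epoch in range(cfg.Train.First_Stage_Epochs + cfg.Train.Second_Stage_Epochs):
        # First stage trains only the head; second stage trains everything
        train_op = (model.train_op_with_frozen_variables
                    if epoch < cfg.Train.First_Stage_Epochs
                    else model.train_op_with_all_variables)
        for images, labels in get_batch():  # hypothetical data feeder
            _, loss, acc = sess.run([train_op, model.loss, model.accuracy],
                                    feed_dict={model.inputs: images,
                                               model.label_c: labels,
                                               model.trainable: True})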