ResNet Transfer Learning (Part 3): Network Structure (Model.py)

Overview

  The model used in this experiment is ResNet, one of the most widely used models for image classification. Trained on ImageNet it achieves excellent results, and its residual idea has been adopted in network models across many task domains. ResNet has of course evolved into several variants; for a translation and analysis of the ResNet-V1 paper, see the earlier post "ResNet-V1论文翻译及解析". This experiment uses ResNet-V2-50, so the detailed structure of ResNet-V2 is introduced next.

ResNet-V2

ResNet-V2 paper: https://arxiv.org/pdf/1603.05027.pdf.
The core idea of the paper is as follows:

When identity mappings are used, so that the forward signal and the backward gradient pass directly from one block to the next without going through ReLU or other operations, the network performs better.

[Figure 1]
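The paper formalizes this. With an identity shortcut and no transformation after the addition, residual unit l computes x_{l+1} = x_l + F(x_l, W_l), so between any shallow unit l and any deeper unit L:

x_L = x_l + \sum_{i=l}^{L-1} F(x_i, W_i)

\frac{\partial E}{\partial x_l} = \frac{\partial E}{\partial x_L} \left( 1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i, W_i) \right)

The additive 1 means the gradient \partial E / \partial x_L reaches every shallower unit directly, without passing through any weight layer, which is exactly why keeping both paths free of ReLU/BN helps.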

The differences and similarities between ResNet-V1 and ResNet-V2 are shown in the figure below, where (a) is ResNet-V1 and (b) is ResNet-V2:

Differences: (1) The basic unit changes: ResNet-V1 stacks [Conv+BN+ReLU]+[Conv+BN] (post-activation), while ResNet-V2 stacks [BN+ReLU+Conv]+[BN+ReLU+Conv] (pre-activation). (2) Whether a ReLU follows the addition: V1 applies one after the addition, V2 does not.
Similarity: the residual idea is unchanged.

[Figure 2: (a) ResNet-V1 unit, (b) ResNet-V2 unit]
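To make unit (b) concrete, here is a minimal pre-activation bottleneck sketch in the TF1/slim style used by the code below. It is a simplification of tf.contrib.slim.nets.resnet_v2, not the library's exact code; preact_bottleneck and its argument names are illustrative:

import tensorflow as tf
slim = tf.contrib.slim

def preact_bottleneck(x, depth_out, depth_bneck, stride):
    # ResNet-V2 pre-activation: BN + ReLU come *before* every convolution
    preact = tf.nn.relu(slim.batch_norm(x))
    if stride == 1 and depth_out == x.get_shape().as_list()[-1]:
        shortcut = x  # identity shortcut: the direct path the paper argues for
    else:
        shortcut = slim.conv2d(preact, depth_out, [1, 1], stride=stride,
                               activation_fn=None, normalizer_fn=None)
    h = slim.conv2d(preact, depth_bneck, [1, 1], activation_fn=None, normalizer_fn=None)
    h = tf.nn.relu(slim.batch_norm(h))
    h = slim.conv2d(h, depth_bneck, [3, 3], stride=stride, activation_fn=None, normalizer_fn=None)
    h = tf.nn.relu(slim.batch_norm(h))
    h = slim.conv2d(h, depth_out, [1, 1], activation_fn=None, normalizer_fn=None)
    return shortcut + h  # no ReLU after the addition -- the key V2 change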

ResNet-V2-50

This experiment uses the ResNet-V2-50 structure; the detailed structure of each module is shown below:
[Figure 3: per-module structure of ResNet-V2-50]
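For orientation, resnet_v2_50 from tf.contrib.slim.nets stacks 3+4+6+3 of these bottleneck units (each unit's output depth is 4 × base_depth); inside the library the network is assembled roughly like this:

from tensorflow.contrib.slim.nets import resnet_v2

blocks = [
    resnet_v2.resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
    resnet_v2.resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
    resnet_v2.resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
    resnet_v2.resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
]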

Code

The code is organized into the following parts:

  1. Basic training parameters and the network inputs (placeholders)
  2. Loss function (load the model, compute the loss)
  3. Accuracy computation (choose the metrics you need)
  4. Two-stage training (optimizers and the variables each stage updates)
  5. Learning-rate schedule (the warmup-plus-cosine formula shown below)
  6. Training logs
  7. Saving and restoring the model (restoring only specific layers' weights)
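Part 5 implements a linear warmup followed by cosine annealing. Writing s for the global step, s_w = Warmup_Epochs × Step_Per_Epoch, and s_T = (First_Stage_Epochs + Second_Stage_Epochs) × Step_Per_Epoch, the tf.cond in the code computes:

lr(s) = \frac{s}{s_w} \cdot lr_{init}, \quad s < s_w

lr(s) = lr_{end} + \frac{1}{2} (lr_{init} - lr_{end}) \left( 1 + \cos\left( \pi \cdot \frac{s - s_w}{s_T - s_w} \right) \right), \quad s \ge s_w
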
import tensorflow as tf
import numpy as np

from tensorflow.contrib.slim.nets import resnet_v2
from tensorflow.contrib.slim.nets import resnet_v1
# from tensorflow.contrib.slim.nets import vgg
import VGG as vgg  # local VGG definition used in place of the slim version

from config import cfg

class Model:
    def __init__(self):
        self.base_architecture = cfg.ResNet.Base_Architecture[1]
        self.pre_trained_model = cfg.ResNet.Pre_Trained_Model[1]
        self.batch_norm_decay = cfg.ResNet.Batch_Norm_Decay
		
        # Model inputs: NHWC image batch, integer class labels, and a training flag
        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, cfg.Train.Input_Size[0], cfg.Train.Input_Size[1], 3),
                                     name='input')
        self.label_c = tf.placeholder(tf.int64, shape=(cfg.Train.Batch_Size,), name='label')
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')

        self.loss_function = cfg.Train.Loss_Function[0]
		
        # Loss function: build the backbone and compute the classification loss
        with tf.name_scope('loss_function'):
            self.logits, self.end_point = self.model(self.trainable,
                                                     self.pre_trained_model,
                                                     self.base_architecture,
                                                     num_classes=cfg.ResNet.Num_Classes)

            if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                logits_squeeze = tf.squeeze(self.logits, axis=[1, 2])
            else:
                logits_squeeze = self.logits

            one_hot = tf.one_hot(self.label_c, cfg.ResNet.Num_Classes)
            if self.loss_function == 'softmax':
                print('using softmax cross entropy loss function')
                loss_net = tf.losses.softmax_cross_entropy(one_hot, logits_squeeze)
            else:
                print('using sigmoid cross entropy loss function')
                loss_net = tf.losses.sigmoid_cross_entropy(one_hot, logits_squeeze, label_smoothing=0.1)
                # loss_net = tf.nn.weighted_cross_entropy_with_logits()

            # Add L2 regularization (weight decay 0.0005) over all trainable variables
            l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
            self.loss = loss_net + l2*0.0005

            # Snapshot of all variables created so far (used for selective restoring below)
            self.net_variables = tf.global_variables()
		
        # Accuracy and related metrics for monitoring training
        with tf.name_scope('compute_accuracy'):
            if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                if self.loss_function == 'softmax':
                    end_point_soft = self.end_point['predictions']
                    self.end_point_squeeze = tf.squeeze(end_point_soft, axis=[1, 2])
                else:
                    end_point_soft = tf.nn.sigmoid(self.logits)
                    self.end_point_squeeze = tf.squeeze(end_point_soft, axis=[1, 2])
            else:
                self.end_point_squeeze = self.end_point['vgg_16/fc8']

            # Classification accuracy
            correct_prediction = tf.equal(tf.argmax(self.end_point_squeeze, 1), self.label_c)
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

            # Per-batch confusion matrix, to track sensitivity and specificity during training (assumes 2 classes)
            confusion_matrix = tf.confusion_matrix(self.label_c, tf.argmax(self.end_point_squeeze, 1), num_classes=2)

            TN = confusion_matrix[0][0]
            FP = confusion_matrix[0][1]
            FN = confusion_matrix[1][0]
            TP = confusion_matrix[1][1]

            # acc = (TP + TN) / (TP + TN + FP + FN)
            self.sensitive = TP / (TP + FN)  # sensitivity (recall of the positive class)
            self.specify = TN / (TN + FP)    # specificity (recall of the negative class)
		
        # Learning-rate schedule: linear warmup followed by cosine annealing
        with tf.name_scope('learning_rate'):
            self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
            warmup_steps = tf.constant(cfg.Train.Warmup_Epochs * cfg.Train.Step_Per_Epoch,
                                       dtype=tf.float64,
                                       name='warmup_steps')

            # Total number of training steps over both stages
            train_steps = tf.constant((cfg.Train.First_Stage_Epochs +
                                       cfg.Train.Second_Stage_Epochs) * cfg.Train.Step_Per_Epoch,
                                      dtype=tf.float64,
                                      name='train_steps')

            self.learn_rate = tf.cond(pred=self.global_step < warmup_steps,
                                      true_fn=lambda: self.global_step / warmup_steps * cfg.Train.Learn_Rate_Init,
                                      false_fn=lambda: cfg.Train.Learn_Rate_End + 0.5 * (
                                              cfg.Train.Learn_Rate_Init - cfg.Train.Learn_Rate_End) *
                                                       (1 + tf.cos((self.global_step - warmup_steps) / (
                                                               train_steps - warmup_steps) * np.pi))
                                      )
            
            # self.learn_rate = tf.train.exponential_decay(cfg.Train.Learn_Rate_Init,
            #                                              self.global_step,
            #                                              decay_steps=400,
            #                                              decay_rate=0.9)
            global_step_update = tf.assign_add(self.global_step, 1.0)
		
        # Exponential moving average (EMA) over all trainable weights
        with tf.name_scope("define_weight_decay"):
            moving_ave = tf.train.ExponentialMovingAverage(cfg.ResNet.Moving_Ave_Decay).apply(tf.trainable_variables())
		
        # Transfer learning, stage 1: train only the classifier layer(s); freeze everything else
        with tf.name_scope("first_train_stage"):
            # Variables to optimize in the first stage
            self.first_stage_trainable_var_list = []
            for var in tf.trainable_variables():
                var_name = var.op.name

                # print('var_name: ', var_name)
                var_name_mess = str(var_name).split('/')
                print(var_name_mess)

                if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                    if var_name_mess[1] in ['logits']:
                        self.first_stage_trainable_var_list.append(var)
                else:
                    if var_name_mess[1] in ['fc6', 'fc7', 'fc8']:
                        self.first_stage_trainable_var_list.append(var)

            first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
                                                                                     var_list=self.first_stage_trainable_var_list)
            # The train op is a no_op whose control dependencies force, in order:
            # the BN statistics updates (UPDATE_OPS), the Adam step together with
            # the global-step increment, and finally the EMA update of the weights.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([first_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_frozen_variables = tf.no_op()
                    # self.train_op_with_frozen_variables = tf.group(moving_ave)
		
        # Stage 2: fine-tune all parameters
        with tf.name_scope("second_train_stage"):
            second_stage_trainable_var_list = tf.trainable_variables()
            second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
                                                                                      var_list=second_stage_trainable_var_list)

            # Same control-dependency chain as stage 1, but every variable is updated
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                with tf.control_dependencies([second_stage_optimizer, global_step_update]):
                    with tf.control_dependencies([moving_ave]):
                        self.train_op_with_all_variables = tf.no_op()
                    # self.train_op_with_all_variables = tf.group(moving_ave)
		
        # Restoring and saving the model
        with tf.name_scope('loader_and_saver'):
            # Restore every variable except the final classifier layer(s)
            variables_to_restore = []
            for v in self.net_variables:
                print("OP Name: ", v.op.name)
                if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
                    if v.name.split('/')[1] not in ['logits']:
                        print('v: ', v)
                        variables_to_restore.append(v)
                else:
                    print('vgg model')
                    if v.name.split('/')[1] not in ['fc6', 'fc7', 'fc8']:
                        # print('v: ', v)
                        variables_to_restore.append(v)
            print('==================================')
            # self.loader = tf.train.Saver(self.net_var)
            self.loader = tf.train.Saver(variables_to_restore)

            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=300)
		
        # Collect summaries so TensorBoard can visualize training
        with tf.name_scope('collect_summary'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.scalar('learning_rate', self.learn_rate)
            tf.summary.scalar('sensitive', self.sensitive)
            tf.summary.scalar('specify', self.specify)

            self.merged = tf.summary.merge_all()
	
    # Build the backbone from the official slim model definitions
    def model(self, is_training, pre_trained_model, base_architecture, num_classes):
        """
        Load the network structure.
        :param is_training: training-phase flag (bool placeholder)
        :param pre_trained_model: path to the pre-trained weights
        :param base_architecture: network name
        :param num_classes: number of output classes
        :return: the network's final logits and a dict of all end points
        """
        if base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
            print('using ', base_architecture)
            batch_norm_decay = 0.997  # hard-coded; cfg.ResNet.Batch_Norm_Decay is loaded in __init__ but unused here

            # Pick the slim model definition that matches the configured name
            if base_architecture == 'resnet_v1_50':
                base_model, arg_scope = resnet_v1.resnet_v1_50, resnet_v1.resnet_arg_scope
            elif base_architecture == 'resnet_v2_101':
                base_model, arg_scope = resnet_v2.resnet_v2_101, resnet_v2.resnet_arg_scope
            else:
                base_model, arg_scope = resnet_v2.resnet_v2_50, resnet_v2.resnet_arg_scope
            with tf.contrib.slim.arg_scope(arg_scope(batch_norm_decay=batch_norm_decay)):
                logits, end_points = base_model(self.inputs, num_classes=num_classes, is_training=is_training)
        else:
            print('using vgg16')
            weight_decay = 0.0005
            base_model = vgg.vgg_16
            with tf.contrib.slim.arg_scope(vgg.vgg_arg_scope(weight_decay=weight_decay)):
                logits, end_points = base_model(self.inputs,
                                                num_classes=num_classes,
                                                is_training=is_training,
                                                spatial_squeeze=is_training)

        # if is_training:
        #     exclude = [base_architecture + '/logits', 'global_step']  # , 'global_step'
        #     print('exclude: ', exclude)
        #
        #     variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
        #     print("variables to restore: ", variables_to_restore)
        #     #
        #     tf.train.init_from_checkpoint(pre_trained_model,
        #                                   {v.name.split(':')[0]: v for v in variables_to_restore})
        #     print('restore form pretrained model\n')

        return logits, end_points
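
For completeness, a minimal usage sketch of this class. The zero-filled feed data is a stand-in; real images and labels come from the data pipeline in the earlier parts of this series:

import numpy as np
import tensorflow as tf

from Model import Model
from config import cfg

model = Model()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore every layer except the classifier from the pre-trained checkpoint
    model.loader.restore(sess, model.pre_trained_model)

    # One stage-1 step (only the classifier weights move); switch to
    # model.train_op_with_all_variables for stage 2
    images = np.zeros((cfg.Train.Batch_Size,
                       cfg.Train.Input_Size[0], cfg.Train.Input_Size[1], 3), np.float32)
    labels = np.zeros((cfg.Train.Batch_Size,), np.int64)
    _, loss = sess.run([model.train_op_with_frozen_variables, model.loss],
                       feed_dict={model.inputs: images,
                                  model.label_c: labels,
                                  model.trainable: True})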
