AI Challenger scene classification: TensorFlow Inception-ResNet-v2, LB: 0.94361

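The model is tf-slim's Inception-ResNet-v2, pretrained on ImageNet. You can choose which layers to train: for example, retrain only the last layer or retrain several of the final layers. No special data augmentation was used beyond tf-slim's default Inception input preprocessing. The following configuration scored 0.94361 on the online leaderboard. These flag values override the defaults in the code.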
All images in the tfrecord files were resized to 299*299 before conversion. See my earlier blog post for the method.
learning_rate=0.0001
batch_size=32
num_epochs=80
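The sketch below checks how long this run is. main() estimates the number of training steps as num_epochs × 53879 / batch_size, where 53879 is the training-set size hard-coded in the code. This is a small pure-Python version of that arithmetic for the configuration above. The `full_batches` count relies on documented tf.train.shuffle_batch behavior: without allow_smaller_final_batch, the leftover examples are discarded when the queue closes. `training_steps` is a hypothetical helper, not part of the script.

```python
TRAIN_SIZE = 53879  # training-set size hard-coded in main()


def training_steps(num_epochs, batch_size, train_size=TRAIN_SIZE):
    """Return (estimate printed by main(), number of full batches).

    With allow_smaller_final_batch=False the input queue discards the
    leftover examples, so only full batches are trained on.
    """
    total_examples = num_epochs * train_size
    estimate = 1.0 * total_examples / batch_size  # same formula as main()
    full_batches = total_examples // batch_size
    return estimate, full_batches


estimate, full = training_steps(num_epochs=80, batch_size=32)
print("estimated steps: %.1f" % estimate)  # estimated steps: 134697.5
print("full batches:    %d" % full)        # full batches:    134697
```

Only 16 examples are dropped at the end of the 80-epoch run, so the estimate printed by main() is a good progress indicator.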
Details:
training accuracy: 0.836019 (streaming mean over all training steps, with augmentation and dropout active)
Final testing accuracy: 0.945787 (val)
Final testing accuracy: 0.94361 (testA)
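The accuracies printed by the code are top-3 accuracies. A prediction counts as correct when the true label is among the three largest logits (tf.nn.in_top_k with k=3), which matches the three label_ids written per image. Below is a pure-Python sketch of the metric with made-up logits. The tie rule follows the tf.nn.in_top_k documentation: classes tied across the k-th position all count as being in the top k. The helper names are illustrative.

```python
def in_top_k(logits, target, k=3):
    """True if the target class is among the k highest logits.

    Mirrors the tie rule documented for tf.nn.in_top_k: classes tied at the
    k-th position all count as being in the top k, so the target is a hit
    when fewer than k classes score strictly higher than it.
    """
    target_score = logits[target]
    num_higher = sum(1 for score in logits if score > target_score)
    return num_higher < k


def top_k_accuracy(batch_logits, labels, k=3):
    hits = [in_top_k(row, label, k) for row, label in zip(batch_logits, labels)]
    return sum(hits) / float(len(hits))


# Three examples over five classes (made-up numbers).
batch_logits = [
    [0.1, 2.0, 0.3, 1.5, 0.0],  # label 3 ranks 2nd -> hit
    [3.0, 0.2, 0.1, 0.4, 2.5],  # label 2 ranks last -> miss
    [0.5, 0.5, 0.5, 0.1, 0.9],  # label 1 tied across the top-3 boundary -> hit
]
labels = [3, 2, 1]
print(top_k_accuracy(batch_logits, labels))
```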
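There still seems to be plenty of room for improvement through tuning, for example data augmentation, input resolution and number of epochs. However: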

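This code currently has one problem: it does not monitor validation accuracy during training. This is a known pitfall of using tfrecord in TensorFlow. Working around it takes some rather ugly hand-written code, and it is still unsolved. This matters a lot, because I have already observed overfitting with some parameter settings. Newer versions of TF are gradually addressing this; see the two issues cited in the comments at the top of the code. The image-loading approach in the official code solves the problem easily. However, reading may be about twice as slow, and it cannot be used on some cloud computing platforms.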
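The code works around this with "magic" constants: a validation-set size of 7120 and a validation batch size of 128. The core requirement is that each evaluation consumes exactly one pass over the validation set, with a smaller final batch, and weights each batch by its size. Below is a framework-free sketch of that bookkeeping. `val_batch_sizes` and `evaluate_one_pass` are hypothetical helpers. The lambda stands in for running the model on one batch, which in the TF 1.x script would be a sess.run call.

```python
import math

VAL_SIZE = 7120       # validation-set size used in the code
VAL_BATCH_SIZE = 128  # validation batch size used in the code


def val_batch_sizes(val_size=VAL_SIZE, batch_size=VAL_BATCH_SIZE):
    """Sizes of the batches that cover the validation set exactly once."""
    num_batches = int(math.ceil(val_size / float(batch_size)))
    sizes = [batch_size] * (num_batches - 1)
    sizes.append(val_size - batch_size * (num_batches - 1))  # smaller final batch
    return sizes


def evaluate_one_pass(batch_hits, val_size=VAL_SIZE, batch_size=VAL_BATCH_SIZE):
    """Accuracy over exactly one pass of the validation set.

    batch_hits(i, n) is a hypothetical callback returning the number of
    correct predictions in batch i of size n.
    """
    total_hits = 0
    total_seen = 0
    for i, n in enumerate(val_batch_sizes(val_size, batch_size)):
        total_hits += batch_hits(i, n)
        total_seen += n
    # Dividing by the image count (not the batch count) keeps the smaller
    # final batch from being over-weighted.
    return total_hits / float(total_seen)


sizes = val_batch_sizes()
print(len(sizes), sizes[-1])  # 56 80
print(evaluate_one_pass(lambda i, n: n // 2))  # 0.5
```

With 7120 images and batches of 128, one pass is 55 full batches plus a final batch of 80. A pipeline that always emits full batches would therefore either skip images or spill into the next pass.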

# -*- coding: utf-8 -*-
"""
Created on Wed Sep 20 16:05:02 2017

@author: wayne

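FEELINGS
Native TF and tfrecord still have quite a few pitfalls, and a lot of "generic code" has to be written by hand, especially the input pipeline and the control flow of training/validation together with accuracy monitoring.
Datasets were introduced in the latest release, 1.3; for the features planned for 1.4 see
https://github.com/tensorflow/tensorflow/issues/7902
and
https://github.com/tensorflow/tensorflow/issues/7951
So far, PyTorch actually seems easier to use; its code is more intuitive and easier to read.

This script combines native TF modules with slim models. Consider learning slim's official template code instead, although it is more abstract.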

CHANGES
- Can restore our own previously saved model instead of always starting from the official checkpoint: tf.flags.DEFINE_bool('use_official', True)

REFERENCES

https://web.stanford.edu/class/cs20si/syllabus.html

Input data
https://stackoverflow.com/questions/44054656/creating-tfrecords-from-a-list-of-strings-and-feeding-a-graph-in-tensorflow-afte
https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
https://indico.io/blog/tensorflow-data-input-part2-extensions/

Overall architecture
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb

Saving and loading models
http://blog.csdn.net/u014595019/article/details/53912710
http://blog.csdn.net/u012436149/article/details/52883747 (restoring a subset of variables)
https://github.com/SymphonyPy/Valified_Code_Classify/tree/master/Classified
http://blog.csdn.net/lwplwf/article/details/76177296 (defines a loop that watches for new checkpoints and runs a validation pass whenever one is written)

Transfer learning (tutorials combining native TF modules with slim CNN models are surprisingly rare!)
https://github.com/AIChallenger/AI_Challenger/tree/master/Baselines/caption_baseline (uses a slim CNN)
https://github.com/kwotsin/transfer_learning_tutorial (a fairly complete program, but it uses only slim-provided modules, plus tf.train.Supervisor and TensorBoard)
http://blog.csdn.net/ArtistA/article/details/52860050 (CNN implemented directly in TF): https://github.com/joelthchao/tensorflow-finetune-flickr-style
http://blog.csdn.net/nnnnnnnnnnnny/article/details/70244232 (tensorflow_inception_graph.pb. Because each training example is used many times, the feature vectors that Inception-v3 computes for the raw images can be saved to files to avoid recomputing them.)
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py 

https://github.com/tensorflow/models/issues/391         [slim] weird result with parameter is_training
https://github.com/YanWang2014/models/tree/master/slim  (various slim models)
http://pytorch.org/docs/master/torchvision/models.html
http://data.mxnet.io/models/

Data augmentation
https://github.com/wzhang1/iNaturalist   MXNet finetune baseline (res152) for challenger.ai/competition/scene
https://github.com/AIChallenger/AI_Challenger/tree/master/Baselines/caption_baseline/im2txt/im2txt/ops

Hyperparameter tuning
https://zhuanlan.zhihu.com/p/22252270    A comprehensive comparison of deep learning optimizers (SGD, Adagrad, Adadelta, Adam, Adamax, Nadam)
http://www.360doc.com/content/16/1010/08/36492363_597225745.shtml

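https://www.zhihu.com/question/41631631  What are your experiences tuning deep learning (RNN, CNN) hyperparameters?
https://www.zhihu.com/question/25097993  What tricks are there for tuning deep learning hyperparameters?
https://www.zhihu.com/question/24529483  What role do weight decay, momentum and normalization play in neural networks?
https://zhuanlan.zhihu.com/p/27555858?utm_medium=social&utm_source=wechat_session  [Intro] How to tune hyperparameters with sophisticated methods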

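The tfrecord validation-set problem: lots of hacky workarounds, differing in whether an extra graph is built
https://github.com/tensorflow/tensorflow/issues/7902    Each validation pass must read exactly the whole validation set, and it has to be read many times; how can this be done (elegantly) with tfrecord?
https://github.com/tensorflow/tensorflow/issues/7951    The input pipeline will be improved in a new version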
https://stackoverflow.com/questions/39187764/tensorflow-efficient-feeding-of-eval-train-data-using-queue-runners
https://stackoverflow.com/questions/44270198/when-using-tfrecord-how-can-i-run-intermediate-validation-check-a-better-way

https://stackoverflow.com/questions/40146428/show-training-and-validation-accuracy-in-tensorflow-using-same-graph

Visualizing the learning rate of AdamOptimizer
https://stackoverflow.com/questions/36990476/getting-the-current-learning-rate-from-a-tf-train-adamoptimizer/44688307#44688307

"""

from __future__ import division, print_function, absolute_import

import os
import time

import tensorflow as tf
from inception_resnet_v2 import *
import inception_preprocessing

slim = tf.contrib.slim
tf.reset_default_graph()

FLAGS = tf.flags.FLAGS

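tf.flags.DEFINE_bool('train_flag', False, 'train_flag')
tf.flags.DEFINE_string('trainable_scopes', 'InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits', 'scopes of the layers to train') # the string 'None' trains all layers; ignored at test time
tf.flags.DEFINE_bool('use_official', True, 'start training from the official model (True) or from our own saved model (False); back up your own model first, otherwise it may be overwritten')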

tf.flags.DEFINE_float('learning_rate', 0.001, 'learning_rate')
tf.flags.DEFINE_string('val_test', 'None', 'tfrecord file to evaluate when train_flag=False: val.tfrecord, or the testA/testB file')
#0.1 for the last layer
#tried 1e-3, 5e-4. 0.001 for the last layer, 0.0001 for the whole network? also 0.1, 0.05, 0.00001

tf.flags.DEFINE_float('beta1', 0.9, 'beta1')
tf.flags.DEFINE_float('beta2', 0.999, 'beta2')
tf.flags.DEFINE_float('epsilon', 0.1, 'epsilon') # TF default is 1e-8; for Inception on ImageNet, 1.0 or 0.1 is suggested

tf.flags.DEFINE_integer('batch_size', 2, 'batch size')
tf.flags.DEFINE_integer('num_epochs', 1, 'epochs')

tf.flags.DEFINE_string('buckets', 'oss://scene2017', 'folder holding the tfrecord files')
official_model_path = 'oss://scene2017/slim/inception_resnet_v2_2016_08_30.ckpt'

tf.flags.DEFINE_string('checkpointDir', 'oss://scene2017', 'model output folder')
tf.flags.DEFINE_string('writes', 'oss://scene2017/slim/submit.txt', 'where to save the predictions')
# Define every flag before FLAGS is first read: the next line reads FLAGS at import time.
model_path = os.path.join(FLAGS.checkpointDir, 'model.ckpt')    # the fine-tuned model


image_size = inception_resnet_v2.default_image_size #  299
num_labels = 80

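'''
Each validation pass must read exactly the whole validation set, and the next pass must read it all again. There is currently no (elegant) way to do this with tfrecord, so we control the queue manually with "magic" constants:
https://github.com/tensorflow/tensorflow/issues/7951
'''
magic_val_len = 7120 # size of the validation set
magic_val_batch_size = 128 # the validation batch size can be large, as long as RAM/GPU memory allows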

def read_and_decode(tfrecord_file, batch_size, num_epochs):  

    filename_queue = tf.train.string_input_producer([tfrecord_file], num_epochs = num_epochs)  
    reader = tf.TFRecordReader()  
    _, serialized_example = reader.read(filename_queue)  

    img_features = tf.parse_single_example(  
                                        serialized_example,  
                                        features={  
                                               'label': tf.FixedLenFeature([], tf.int64),  
                                               'h': tf.FixedLenFeature([], tf.int64),
                                               'w': tf.FixedLenFeature([], tf.int64),
                                               'c': tf.FixedLenFeature([], tf.int64),
                                               'image': tf.FixedLenFeature([], tf.string),  
                                               })  

    h = tf.cast(img_features['h'], tf.int32)
    w = tf.cast(img_features['w'], tf.int32)
    c = tf.cast(img_features['c'], tf.int32)

    image = tf.decode_raw(img_features['image'], tf.uint8)  
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(img_features['label'],tf.int32) 

    ##########################################################  
    '''data augmentation here'''   
#    distorted_image = tf.random_crop(images, [530, 530, img_channel])
#    distorted_image = tf.image.random_flip_left_right(distorted_image)
#    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
#    distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)

#    image = tf.image.resize_images(image, (image_size,image_size))
#    image = tf.image.per_image_standardization(image)
#    image = tf.reshape(image, [image_size * image_size * 3])
    image = inception_preprocessing.preprocess_image(image, 
                                                     image_size, 
                                                     image_size,
                                                     is_training=True)

    ##########################################################
    '''shuffle here'''
    image_batch, label_batch = tf.train.shuffle_batch([image, label],       
                                                batch_size= batch_size,  
                                                num_threads= 64,    # note: with multiple threads the image order may change
                                                capacity = 10240,
                                                min_after_dequeue= 256
                                                )

    return image_batch, label_batch

def read_and_decode_test(tfrecord_file, batch_size, num_epochs):  

    filename_queue = tf.train.string_input_producer([tfrecord_file], num_epochs = num_epochs)  
    reader = tf.TFRecordReader()  
    _, serialized_example = reader.read(filename_queue)  

    img_features = tf.parse_single_example(  
                                        serialized_example,  
                                        features={  
                                               'label': tf.FixedLenFeature([], tf.int64),  
                                               'h': tf.FixedLenFeature([], tf.int64),
                                               'w': tf.FixedLenFeature([], tf.int64),
                                               'c': tf.FixedLenFeature([], tf.int64),
                                               'image': tf.FixedLenFeature([], tf.string),   #https://stackoverflow.com/questions/41921746/tensorflow-varlenfeature-vs-fixedlenfeature
                                               'image_id': tf.FixedLenFeature([], tf.string)                  
                                               })  

    h = tf.cast(img_features['h'], tf.int32)
    w = tf.cast(img_features['w'], tf.int32)
    c = tf.cast(img_features['c'], tf.int32)
    image_id = img_features['image_id']

    image = tf.decode_raw(img_features['image'], tf.uint8)  
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(img_features['label'],tf.int32) 

    ##########################################################  
    '''no data augmentation'''   
    #image = tf.image.resize_images(image, (image_size,image_size))
#    image = tf.image.per_image_standardization(image)
#    image = tf.reshape(image, [image_size * image_size * 3])
    image = inception_preprocessing.preprocess_image(image, 
                                                     image_size, 
                                                     image_size,
                                                     is_training=False)
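    # tf.train.batch needs a fully defined shape, but the channel dimension is unknown here because
    # the image was reshaped with the dynamic c read from the record. preprocess_for_train restores it
    # with set_shape([None, None, 3]); preprocess_for_eval does not, so set it explicitly.
    image.set_shape([None, None, 3])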

    image_batch, label_batch, image_id_batch= tf.train.batch([image, label, image_id],   
                                                batch_size= batch_size,  
                                                num_threads= 64,    # note: with multiple threads the image order may change
                                                capacity = 2000,
                                                allow_smaller_final_batch = True
                                                )  

    return image_batch, label_batch, image_id_batch 

def batch_to_list_of_dicts(indices2, image_id_batch2):
    # one entry per image, e.g. [{"image_id":"a0563eadd9ef79fcc137e1c60be29f2f3c9a65ea.jpg","label_id": [5,18,32]}]
    return [{'image_id': image_id_batch2[item].decode(),
             'label_id': indices2[item, :].tolist()}
            for item in range(indices2.shape[0])]

'''https://github.com/tensorflow/models/blob/master/research/slim/train_image_classifier.py'''
def get_variables_to_train():
    """Returns a list of variables to train.
    Returns:
      A list of variables to train by the optimizer.
    """    
    trainable_scopes = FLAGS.trainable_scopes

    if trainable_scopes == "None":
        print("train all layers")
        return tf.trainable_variables()
    else:
        print("train the specified layers")
        scopes = [scope.strip() for scope in trainable_scopes.split(',')]

    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
#    variables_to_train = [i.name for i in variables_to_train]
    return variables_to_train

def read_tfrecord2(tfrecord_file, batch_size, train_flag, num_epochs, total_steps):

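    # Test records carry an image_id; otherwise train and test could share one input function. Also, read_and_decode adds data augmentation for training, so both the validation and test sets use the second function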
    if train_flag:
        train_batch, train_label_batch = read_and_decode(tfrecord_file, batch_size, num_epochs)

        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            train_logits, end_points = inception_resnet_v2(train_batch, num_classes = num_labels, is_training = True)
        #Define the scopes that you want to exclude for restoration
        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
        variables_to_restore = slim.get_variables_to_restore(exclude = exclude)
        variables_to_train = get_variables_to_train()

        #Performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits but enhanced with checks
        loss = tf.losses.sparse_softmax_cross_entropy(labels=train_label_batch, logits=train_logits)
        #slim.losses.add_loss(pose_loss)
        total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well

        #http://blog.csdn.net/xierhacker/article/details/53174558
        optimizer = tf.train.AdamOptimizer(
                                            learning_rate=FLAGS.learning_rate,
                                            beta1=FLAGS.beta1,
                                            beta2=FLAGS.beta2,
                                            epsilon=FLAGS.epsilon,
                                            use_locking=False,
                                            name='Adam'
                                          )

        '''Use this function to choose which layers are trained; by default all trainable variables are trained: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py#L374'''
        train_op = slim.learning.create_train_op(total_loss, optimizer, variables_to_train = variables_to_train)

        '''minibatch accuracy, non-streaming'''
        train_accuracy_batch = tf.reduce_mean(tf.cast(tf.nn.in_top_k(predictions = train_logits, targets=train_label_batch, k=3),tf.float32))
        '''Streaming accuracy'''
        train_accuracy, train_accuracy_update= tf.metrics.mean(tf.cast(tf.nn.in_top_k(predictions = train_logits, targets=train_label_batch, k=3),tf.float32))

    else:
        val_test_batch, val_test_label_batch, image_id_batch= read_and_decode_test(tfrecord_file, batch_size, num_epochs) 

        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            val_test_logits, end_points = inception_resnet_v2(val_test_batch, num_classes = num_labels, is_training = False)

        '''Useless minibatch accuracy, non-streaming'''
        val_test_accuracy_batch = tf.reduce_mean(tf.cast(tf.nn.in_top_k(predictions = val_test_logits, targets=val_test_label_batch, k=3),tf.float32))
        '''Streaming accuracy'''
        val_test_accuracy, val_test_accuracy_update= tf.metrics.mean(tf.cast(tf.nn.in_top_k(predictions = val_test_logits, targets=val_test_label_batch, k=3),tf.float32))

        values, indices = tf.nn.top_k(val_test_logits, 3)

    saver = tf.train.Saver() # create the saver
    if train_flag:
        if FLAGS.use_official:
            saver_step0 = tf.train.Saver(variables_to_restore)
        else:
            saver_step0 = tf.train.Saver()

    with tf.Session() as sess:
        # https://github.com/tensorflow/tensorflow/issues/1045
        sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
        print("Initialized")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        if train_flag:

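            '''
            When use_official is True, the logits layers (whose number of classes was changed) are excluded from
            variables_to_restore, so restore() below does not touch them: they keep the values assigned by
            tf.global_variables_initializer() above.
            '''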
            if FLAGS.use_official:
                saver_step0.restore(sess, official_model_path)
            else:
                saver_step0.restore(sess, model_path)

            try:
                step = 0
                start_time = time.time()
                while not coord.should_stop():
                    _, l, logits2, train_acc2_batch, train_acc2, train_acc2_update = sess.run([train_op, total_loss, train_logits, train_accuracy_batch, train_accuracy, train_accuracy_update])

                    duration = time.time() - start_time

                    if (step % 10 == 0):
                        print("Minibatch loss at step %d - %d: %.6f (%.3f sec)" % (step, total_steps, l, duration))
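                        print("Minibatch accuracy: %.6f" % train_acc2_batch)
                        # optimizer._lr is the base learning_rate passed to AdamOptimizer, so this value does not change during training
                        print("lr: %.6f" % optimizer._lr) #https://stackoverflow.com/questions/38882593/learning-rate-doesnt-change-for-adamoptimizer-in-tensorflow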
                    #if (step % 100 == 0):
                    #Validating accuracy

                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training for %d epochs, %d steps.' % (num_epochs, step))
                print('Final training accuracy: %.6f' % (train_acc2_update))
                #Final Validating accuracy

                saver.save(sess, model_path)
            finally:        
                coord.request_stop()

        else: 
            saver.restore(sess, model_path) # restores the saved values into the variables
            results = []
            try:
                step = 0
                start_time = time.time()
                while not coord.should_stop():
                    val_test_logits2, val_test_acc2_batch, val_test_acc2, val_test_acc2_update,image_id_batch2, indices2, values2= sess.run([val_test_logits, val_test_accuracy_batch, val_test_accuracy, val_test_accuracy_update, image_id_batch, indices, values])
                    step += 1

                    results += batch_to_list_of_dicts(indices2, image_id_batch2)
                    if (step % 10 == 0):
                        print('Useless minibatch testing accuracy at step %d: %.6f' % (step, val_test_acc2_batch))
                        print(indices2.shape[0])

            except tf.errors.OutOfRangeError:
                print('Done testing in %d steps.' % (step))
                print('Final testing accuracy: %.6f' % (val_test_acc2_update))


                '''Writing JSON data'''
                # e.g. results = [{"image_id":"a0563eadd9ef79fcc137e1c60be29f2f3c9a65ea.jpg","label_id": [5,18,32]}]
                print(len(results))
                # json.dumps writes valid JSON; str(results) would write a Python repr with single quotes.
                # PAI pitfall: the built-in open() cannot write to oss:// paths, so go through tf.gfile.
                import json
                with tf.gfile.GFile(FLAGS.writes, 'w') as f:
                    f.write(json.dumps(results))
            finally:        
                coord.request_stop()

        coord.join(threads)

def main(_):

    train_flag = FLAGS.train_flag


    if train_flag:
        tfrecord_file = os.path.join(FLAGS.buckets,'train.tfrecord') 
#'../ai_challenger_scene_train_20170904/train.tfrecord'
#    tfrecord_file_val = '../ai_challenger_scene_train_20170904/val.tfrecord' # validate while training
        batch_size = FLAGS.batch_size#256
        num_epochs = FLAGS.num_epochs
        total_steps = 1.0 * num_epochs * 53879 / batch_size
        print("total_steps is %d" % total_steps)
        print("num_epochs is %d" % num_epochs)
        print("batch_size is %d" % batch_size)
        print("lr %.6f" % FLAGS.learning_rate)
        read_tfrecord2(tfrecord_file, batch_size, train_flag, num_epochs, total_steps)
    else:
        tfrecord_file = os.path.join(FLAGS.buckets,FLAGS.val_test)#'../ai_challenger_scene_train_20170904/val.tfrecord'  #test
        batch_size = FLAGS.batch_size #16 
        num_epochs = FLAGS.num_epochs #1
        total_steps = 1.0 * num_epochs * 7120 / batch_size  # 7120 is the size of val.tfrecord; the estimate is slightly off for the test sets, which is acceptable
        print("total_steps is %d" % total_steps)
        read_tfrecord2(tfrecord_file, batch_size, train_flag, num_epochs, total_steps)

# image counts: train 53879, val 7120, test 7040


if __name__ == "__main__": # ensures main() does not run when this file is imported by another module
    tf.app.run() # parses the command-line flags and calls main(sys.argv)
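get_variables_to_train above picks the variables to update by scope. tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) keeps the variables whose names match the scope string with re.match, that is, anchored at the start of the name. The flag value "None" selects everything. Below is a pure-Python sketch of that selection on plain name strings. `select_trainable` is a hypothetical helper, and the variable names only imitate the slim naming style; they are not read from the real checkpoint.

```python
import re


def select_trainable(variable_names, trainable_scopes):
    """Mimic the scope filtering in get_variables_to_train on name strings.

    trainable_scopes is the comma-separated flag value; the string "None"
    means every variable is trained. Each scope is matched with re.match,
    i.e. it must match at the start of the variable name.
    """
    if trainable_scopes == "None":
        return list(variable_names)
    scopes = [s.strip() for s in trainable_scopes.split(",")]
    selected = []
    for scope in scopes:
        selected.extend(n for n in variable_names if re.match(scope, n))
    return selected


# Illustrative variable names in the style of the slim model.
names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights:0",
    "InceptionResnetV2/AuxLogits/Logits/weights:0",
    "InceptionResnetV2/Logits/Logits/weights:0",
    "InceptionResnetV2/Logits/Logits/biases:0",
]
last_layers = select_trainable(
    names, "InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits")
print(last_layers)
```

Because matching is anchored at the start, "InceptionResnetV2/Logits" does not pick up "InceptionResnetV2/AuxLogits/Logits/...". The auxiliary head has to be listed separately, as the default flag value does.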
