tensorflow学习之Faster R-CNN模型的保存与加载

   在近期浏览论文的过程中发现近期新出的论文中,很少在caffe框架上进行实验验证,随之而来的是tensorflow、MXNet、PyTouch这些深度学习框架,为了跟踪前沿技术,作为学生的笔者无奈只能从原始的caffe框架使用转向了tensorflow。选择tensorflow作为新的框架来进行学习主要是因为其具有以下特点:

1. 可用性
    TensorFlow 工作流程相对容易,API 稳定,兼容性好,并且 TensorFlow 与 Numpy 完美结合,这使大多数精通 Python 数据科学家很容易上手。与其他一些库不同,TensorFlow 不需要任 何编译时间, 这允许你可以更快地迭代想法。在TensorFlow 之上 已经建立了多个高级 API,例如Keras 和 SkFlow,这给用户使用TensorFlow 带来了极大的好处。
2. 灵活性
    TensorFlow 能够在各种类型的机器上运行,从超级计算机到嵌入式系统。它的分布式架构使大量数据集的模型训练不需要太多的时 间。TensorFlow 可以同时在多个 CPU,GPU 或者两者混合运行。
3. 效率
    自 TensorFlow 第一次发布以来,开发团队花费了大量的时间和努力 来改进TensorFlow 的大部分的实现代码。 随着越来越多的开发人 员努力,TensorFlow 的效率不断提高。
4. 支持
    TensorFlow 由谷歌提供支持,谷歌投入了大量精力开发 TensorFlow,它希望 TensorFlow 成为机器学习研究人员和开发人员的通用语言。此外,谷歌在自己的日常工作中也使用 TensorFlow,并且持续对其提供支持,在 TensorFlow 周围形成了 一个强大的社区。谷歌已经在 TensorFlow 上发布了多个预先训练好的机器学习模型,他们可以自由使用。

    但是在放弃caffe从头开始学习tensorflow时还是踩了不少的坑,首先就是tensorflow框架训练完成的模型如何使用的问题。如果不能解决训练完成的模型如何使用的问题,那就没办法验证自己的想法的正确性,对于训练也就没有什么意义了。关于模型的存储和训练完成模型的使用,我查了很多现在出版的书籍,以及网上的一些信息,但是普遍的情况都是一些简单的保存一两个变量,例如:
    模型参数的保存

    import tensorflow as tf
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my_test_model')

    模型参数的加载

    '''restore tensor from model'''
    w_out= self.graph.get_tensor_by_name('W:0')
    b_out = self.graph.get_tensor_by_name('b:0')
    _input = self.graph.get_tensor_by_name('x:0')
    _out = self.graph.get_tensor_by_name('y:0')
    y_pre_cls = self.graph.get_tensor_by_name('output:0')

    这样简单的数据保存学习完之后完全不能满足我们在做深度网络模型训练和训练完成模型的使用方面的任务,于是无奈只能在从官网上查找一些信息来一探究竟了。tensorflow模型训练完成之后将会生成4个文件(其中有3个文件是每个snapshot都会生成,一个文件记录的是每个snapshot模型存储的路径信息),其分别为:

  • my_test_model.data-00000-of-00001
  • my_test_model.index
  • my_test_model.meta
  • checkpoint

这4个文件的作用分别为:

  • my_test_model.data-00000-of-00001    #文件包含训练变量
  • my_test_model.index         #存储训练模型保存时的索引信息
  • my_test_model.meta          #存储定义的tensorflow图结构
  • checkpoint          #存储模型保存的路径信息,也是模型加载的重要文件之一

checkpiont文件会在训练阶段会随着其余三个文件的新增而自动添加相关信息到该文件中,其格式为:

model_checkpoint_path: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_70000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_5000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_10000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_15000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_20000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_25000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_30000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_35000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_40000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_45000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_50000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_55000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_60000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_65000.ckpt"
all_model_checkpoint_paths: "/home/smart/Faster-RCNN_TF_ships_dataset/output/faster_rcnn_end2end/voc_2007_trainval/VGGnet_fast_rcnn_iter_70000.ckpt"

    深度网络模型的加载

    在我们训练完成之后,会在指定的文件夹下生成如下图所示的文件:

tensorflow学习之Faster R-CNN模型的保存与加载_第1张图片

    当我们需要将训练完成的模型用于测试或者运行demo时,我们只需要将某一个snapshot生成的3个*.ckpt.*文件和checkpoint文件全部复制到我们需要使用模型的所在目录即可。使用如下指令即可将模型重新加载,完成模型的继续训练或者后续的验证操作。

    saver = tf.train.Saver()
    saver.restore(sess, 'model_path')

    但是在模型加载时,路径的写法将会决定我们模型的加载时候能够正常进行,当我们将之前所说的4个文件复制到指定的文件夹时,例如:“/home/smart/Faster-RCNN_TF_ships_dataset/data/model”文件夹下,我们需要使用的是VGGnet_fast_rcnn_iter_70000这个模型,则在model_path参数应该设置为:“/home/smart/Faster-RCNN_TF_ships_dataset/data/model/VGGnet_fast_rcnn_iter_70000.ckpt”这样设置完成之后,我们的模型就可以正常加载了。

    注意!!! 笔者之前曾将在github上下载了开源的源码以及训练好的模型,发现其中只有一个文件也就是VGGnet_fast_rcnn_iter_70000.ckpt这个文件,我曾试图将tensorflow自动生成的*.ckpt*文件中的任意一个的后缀名去掉使其跟从他人那里下载下来的模型在形式上一致,也就是将VGGnet_fast_rcnn_iter_70000.ckpt.index改为VGGnet_fast_rcnn_iter_70000.ckpt,然后结果发现模型加载不能成功,代码提示内部的数据格式不能满足要求。

    最后就贴一下模型训练的效果吧,也算是一段探索的里程碑!哈哈….
tensorflow学习之Faster R-CNN模型的保存与加载_第2张图片

你可能感兴趣的:(深度学习工程)