【BUG】tensorflow预训练简单模型及权重文件复用初始化复杂模型

笔者在学习YOLO网络的过程中,遇到了预训练问题。我在网上搜到的大部分相关问题都是在说如何利用之前的预训练权重文件做fineturning


问题:如何预训练简单网络,然后复用权重文件初始化复杂网络?


在yolo中,我需要预训练前20层网络,使其能在物体分类上达到不错的准确率

然后复用这个简单网络的权重文件初始化正式的YOLO网络,这是我遇到的实际问题

 

一开始,我想问题的突破口有三个:

1、npy文件保存与读取

      本来我是用ckpt文件作为权重文件的载体,但是因为不太了解ckpt内部结构,结合网上一些介绍,就想着能不能另辟蹊径呢?

      然而我想的是先把手上的资源利用起来,所以有了2和3

2、tf.get_variable()

      受Salvador Dali的启发,要保留权重文件不过就是要保存变量,既要保存变量就得从变量的载体下手,要了解这个函数有哪些参数,返回值是什么,还有其他哪些类似功能的函数

3、tf.train.Saver()
      保留权重文件的另一个方向就是在如何保留入手,这个应该是要和variable变量函数结合起来操作的吧,但一开始我并不知道如何下手

      所以还是到官网,看文档,了解这函数的参数列表,返回值等等,再结合相关的博文,自己做一点测试。

 

 

利用 tf.get_variable()和 tf.train.Saver()解决上述问题

首先了解下tf.get_variable

【BUG】tensorflow预训练简单模型及权重文件复用初始化复杂模型_第1张图片

 

其中collections这个参数有什么作用呢?

可以清楚地看到所有variable都会默认为global

这个参数可以将特定的variable设为local

collections=[tf.GraphKeys.LOCAL_VARIABLES]

 

为什么要这么做呢?接下来看完tf.train.Saver()的介绍,你就清楚了

 

【BUG】tensorflow预训练简单模型及权重文件复用初始化复杂模型_第2张图片

Saver的参数var_list可以选择一个列表的变量,这些个变量的op name会在checkpoint files中作为keys

这说明什么?说明Saver可以任意选择要保存global变量或是local变量,或者两者都保存呢

而这个var_list可以是什么呢?

tf.global_variables()和tf.local_variables()都可以返回一个列表的global或local变量呢,两个加起来不就可以随意保存我想要的变量了吗

 

到这里是不是有思路了?

我可以将简单网络的变量,全部设为global,然后将用save(global)保存下来

然后restore回复杂网络不就可以了?

 

那么问题来了,真的有那么简单吗?

复杂网络的变量是全部设为global,还是部分设为global部分设为local,哪一部分设为local呢

让我做个测试验证一下吧

我要验证的是

当我把一部分global变量保存下来之后,在restore之前加入另一部分glocal变量,这个restore会失败吗?

或者在restore之前加入另一部分local变量,这个restore会成功吗?

 

下面是我用mnist数据集做的一个小测试

 

import tensorflow as tf

import os

output='output'

output_dir=os.path.join(os.path.abspath(output), 'weights')

ckpt_file = os.path.join(output_dir, 'save.ckpt')

from tensorflow.examples.tutorials.mnist import input_data

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

sess=tf.InteractiveSession()

x=tf.placeholder(tf.float32,[None,784])

W=tf.get_variable("W",initializer=tf.zeros([784,10]),collections=[tf.GraphKeys.GLOBAL_VARIABLES])

b=tf.get_variable("b",initializer=tf.zeros([10]),collections=[tf.GraphKeys.GLOBAL_VARIABLES])


#b1=tf.get_variable("b1",initializer=tf.zeros([10]),collections=[tf.GraphKeys.LOCAL_VARIABLES])
    

y=tf.nn.softmax(tf.matmul(x,W)+b)

y_=tf.placeholder(tf.float32,[None,10])

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

tf.global_variables_initializer().run()
tf.local_variables_initializer().run()

variable_to_restore = tf.global_variables()#+tf.local_variables()

saver = tf.train.Saver(variable_to_restore, max_to_keep=None)

is_train=False#True就训练,False为检测
with tf.variable_scope("weights",reuse=True):

    if is_train:
        
        for i in range(1000):
            
            batch_xs,batch_ys=mnist.train.next_batch(100)
            
            train_step.run({x:batch_xs,y_:batch_ys})
            
            if i % 100 ==0:
                
                print('Saving checkpoint file to: {}'.format(output_dir))
                
                saver.save(sess,ckpt_file)
    
    
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        
        print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
    else:
        model_file=tf.train.latest_checkpoint(output_dir)
        saver.restore(sess,model_file)
        print(sess.run(W))#这里是把之前保存的变量取出来观察一下
        print(sess.run(b))
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        
        print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))

 

结果是:

 

当我把GLOBAL的W和b保存下来之后,加了一个b1的local变量,再restore,没有报错

设is_train=False,测试结果也正常

当我把GLOBAL的W和b保存下来之后,加了一个b1的global变量,(collections=[tf.GraphKeys.GLOBAL_VARIABLES])

出现了一下错误

 

 

OK,至此,问题不就解决了吗?

先将简单网络的变量都设为global然后save下来

然后将复杂网络中较之简单网络多出来的变量统统设为local

然后用初始化为Saver(global)的对象1去restore,就可以达到利用简单网络额权重文件初始化复杂网络的目的啦


最后用初始化为Saver(global+local)的对象2去save,就可以把复杂网络的权重文件保存下来啦

 


 

 

 

 

 

 

你可能感兴趣的:(BUG)