读别人的代码的时候经常看到这几个函数:
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_xxx()
FLAGS.parse_flags()
这是tensorflow的命令行参数。在深度学习训练中,我们常常需要动态的配置诸如batch size、learning rate、epoch、kernel size等等超参数,同时在分布式训练时为了区别运行不同的代码,我们也需要配置一个参数用以运行不同代码。那么有无一种比较合适的可以动态配置的方法呢?答案是肯定的,一种是使用python的argparse库,另外一种是使用tensorflow的tf.app.flags组件。
tf.app.flags组件的作用:
这种写法用于帮助我们添加命令行的可选参数。也就是说利用该函数我们可以实现在命令行中选择需要设定的参数来运行程序, 可以不用反复修改源代码中的参数,直接在命令行中进行参数的设定。
即通过在命名行输入不同的文件名、参数,可以快速完成程序的调参和更换训练集的操作,不需要进入源码中更改。相当于对Python中的命令行参数模块optpars(参考: python中处理命令行参数的模块optpars )做了一层封装。
下面具体讲解一下各语句:
一、flags = tf.flags
flags.DEFINE_xxx()
tf.app.flags其实是tensorflow定义的一个类,本质上是基于argparse再封装更友好直观的一个类库,其主要源代码位于“tensorflow/tensorflow/python/platform/flags.py”,flags.py定义了如何构造和解析命令行参数。
optpars中的参数类型是通过参数 “type=xxx” 定义的,tf中每个合法类型都有对应的 “DEFINE_xxx”函数。常用:
“DEFINE_xxx”函数带3个参数,分别是变量名称,默认值,用法描述,例如:
tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''Checkpoint directory to restore''')
该语句定义一个名称是 "ckpt_path" 的变量,默认值是 ckpt_path = 'model/model.ckpt-100000',描述信息表明这是一个用于保存节点信息的路径。
---------------------
引自:https://blog.csdn.net/dcrmg/article/details/79658725
实际上他是一个类名,我们看到源代码有如下定义:
# Provides the global object that can be used to access flags.
FLAGS = _FlagValues()
通过这个类别名,我们可以获取到该类定义的属性和成员。
def mark_flag_as_required(flag_name): #强制要求显示定义一个命令行参数
def mark_flags_as_required(flag_names): #强制要求显示定义N个命令行参数,它是通过调用前者来实现的
这两个函数的作用是一样的,都是强制要求显示定义命令行参数,否则会报错!前者可以强制要求显示定义一个命令行参数,而后者则可以强制要求显示定义N个命令行参数,它是通过调用前者来实现的。如果你用这两个接口来检查命令行参数,但你code里却没有定义,那么运行时就一定会报错!
通常按照如下步骤即可(以下默认都执行了“import tensorflow as tf”):
第一步(获取可以class _FlagValues(object)成员函数的使用句柄):
flags = tf.app.flags
第二步(不是必须的,只是为了更安全的检查):
flags.mark_flags_as_required(flag_names)
或者
flags.mark_flag_as_required(flag_name)
第三步(定义不同类型的命令行参数):
flags.DEFINE_string(flag_name, default_value, docstring)
flags.DEFINE_integer(flag_name, default_value, docstring)
flags.DEFINE_boolean(flag_name, default_value, docstring)
flags.DEFINE_float(flag_name, default_value, docstring)
至此,tensorflow的命令行参数就构造完成了!
解析tensorflow命令行参数更加简单,但是命令行参数名字一定要注意与构造时的一致,否则就会报错没有意义了,通常按照如下步骤解析
第一步(获取可以使用class _FlagValues(object)成员属性的使用句柄):
FLAGS = flags.FLAGS
第二步(直接获取已定义了的命令行参数,不同参数类型均可以按照如下方式获取):
flag_name = FLAGS.flag_name
至此,tensorflow的命令行参数就解析完成了!
FLAGS._parse_flags()
- FLAGS = tf.flags.FLAGS #FLAGS保存命令行参数的数据
- FLAGS._parse_flags() #将其解析成字典存储到FLAGS.__flags中
如下是一个tensorflow命令行参数构造和解析的使用例子,其命令行参数在这里没有实际意义,我们只是为了举例说明,以便观察其运行结果并加深对tensorflow命令行参数原理的理解!
import tensorflow as tf
flags = tf.app.flags # structure first
flags.mark_flags_as_required(['batch_size', 'learning_rate', 'data_dir']) # structure second(可加可不加)
# structure third
#flags.DEFINE_integer("batch_size", 1000, "training batch size")
flags.DEFINE_float("learning_rate", 0.1, "training learning rate")
flags.DEFINE_string("data_dir", "/home/xsr-ai/study", "training data directory")
flags.DEFINE_boolean("use_gpu_device", "False", "use gpu device to training or not")
# parsing first
FLAGS = flags.FLAGS
# parsing second
batch_size = FLAGS.batch_size
learning_rate = FLAGS.learning_rate
data_dir = FLAGS.data_dir
use_gpu_device = FLAGS.use_gpu_device
# finally validation
print("batch_size=%s" % batch_size)
print("learning_rate=%f" % learning_rate)
print("data_dir=%s" % data_dir)
print("use_gpu_device=%d" % use_gpu_device)
我们将这段测试代码复制粘贴到文档testflags.py里,然后执行“python testflags.py”出现了如下错误:
这是因为这行代码被注释了,将如下这行代码的注释去掉:
#flags.DEFINE_integer("batch_size", 1000, "training batch size")
再运行一次代码,结果如下:
结果都按照设定的命令行参数默认值输出了,结果没错!
那为什么第一次运行却会报错呢?原因是mark_flag_as_required,他会进行安全检查确保被他标志了的命令行参数都有显示定义!
下面我们故意在命令行里写错参数名字,看一下会有什么结果:
可以看到“learning_rate”的输出值不是按照命令行里配置的0.01,而是保持了默认值0.1,原因是命令行参数名字写错成“learning_rt”,但也没有报错,这是命令行参数名字写错时的情况,tf.app.flags与argparse的行为表现是一致的!
最后,我们直接在命令行里正确配置参数,看一下运行结果,如下,都正确无误,太棒了!
还有一种写法是:
python文件命名为flags_test.py
# -*- coding=utf-8 -*-
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''模型保存路径''')
tf.app.flags.DEFINE_float('learning_rate',0.0001,'''初始学习率''')
tf.app.flags.DEFINE_integer('train_steps', 50000, '''总的训练轮数''')
tf.app.flags.DEFINE_boolean('is_use_gpu', False, '''是否使用GPU''')
print '模型保存路径: {}'.format(FLAGS.ckpt_path)
print '初始学习率: {}'.format(FLAGS.learning_rate)
print '总的训练次数: {}'.format(FLAGS.train_steps)
print '是否使用GPU: {}'.format(FLAGS.is_use_gpu)
按默认设置执行程序:
参考资料:
tensorflow命令行参数原理详细解析:https://zhuanlan.zhihu.com/p/31380345
tensorflow命令行参数:tf.app.flags.DEFINE_string:https://blog.csdn.net/dcrmg/article/details/79658725