python基础--absl.flags

之前在tensorflow的mnist例程中看到了使用 absl.flags的方法来载入和解析参数的,出于学习的目的,就自己试验了一下,

代码如下:

 1 # *_*coding:utf-8 *_*
 2 # athor:auto
 3 
 4 import sys, os
 5 from absl import app
 6 from absl import flags
 7 from official.utils.flags import core as flags_core
 8 
 9 
10 FLAGS = flags.FLAGS
11 flags.DEFINE_string('gpu', None, 'comma separated list of GPU to use.')
12 
13 
14 def flagtest(argv):
15     del argv
16     if FLAGS.gpu:
17         print("gpu is %s" % FLAGS.gpu)
18         os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
19     else:
20         print('Please assign GPUs.')
21         exit()
22 
23 def main(argv):
24     flags_core.define_base()
25     flags_core.define_performance(num_parallel_calls=False)
26     flags_core.define_image()
27     flags.adopt_module_key_flags(flags_core)
28 
29 if __name__ == '__main__':
30     app.run(flagtest)
View Code

其中main中的几个调用都是源自于tensorflow的model/official,里面的函数大多是model/official/utils/flags/core.py内定义好的一些默认参数。
在mnist例子中还可以这样添加自定义项:

  flags_core.set_defaults(data_dir='./tmp/mnist_data',
                          model_dir='./tmp/mnist_model',
                          batch_size=100,
                          train_epochs=40,
                          stop_threshold=0.998)

  

 

参考:

https://blog.csdn.net/faith_binyang/article/details/80551941

转载于:https://www.cnblogs.com/IGNB/p/10616445.html

你可能感兴趣的:(python基础--absl.flags)