tensorflow使用tf.estimator限制gpu显存

tf.estimator是tensorflow的高阶api,使用下面代码可以实现限制显存,0.8代表使用80%的显存。

session_config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
run_config = tf.estimator.RunConfig(
      session_config=session_config,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps)

之前搜到另一种写法,亲测这种写法在replace之后,前面设置的model_ dir和save_checkpoints_steps均失效了,很尴尬~

run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps)
session_config = tf.ConfigProto(log_device_placement=True,allow_soft_placement=True)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.8
run_config = tf.estimator.RunConfig().replace(session_config=session_config)

所以还是使用第一种写法吧。

你可能感兴趣的:(深度学习)