Keras限制GPU显存使用

深度学习门槛越来越低,尤其是Keras这样的高层次API加入以后,简单几行代码就能构建网络并得到不错的效果。

最近工作需要,开始使用Keras写3DConv做3D数据分类。

数据是肺部CT,目标是对检测网络得到的结果使用3D Conv网络做2分类。2D的检测网络获得肺部结节的ROI,但是由于二维上结节目标与肺部正常组织结构(如血管,气管等)特征类似,所以检测结果中包含了大量“假阳”,所以在检测网络的结果基础上使用3D网络,实际测试对“假阳”有很好抑制效果。

由于检测网络使用caffe框架,而3D网络又用以TensorFlow作为后端的Keras编写,结果TensorFlow跟caffe不兼容,同一个Python进程中不能同时包含这两个框架的环境。

于是将两部分代码分别放到不同的显卡上。TensorFlow如果不加限制会占满服务器上的所有显存,需要加以限制,查看了GitHub上的源码和解决方案如下:

首先是对keras配置文件进行修改,如果是在Linux环境下,也就是修改:

~/.keras/keras.json

这个文件。

修改内容如下:

{
    "epsilon": 1e-07,
    "floatx": "float32",
    "image_data_format": "channels_last",
    "backend": "tensorflow",
    "gpu_options": {
        "allow_growth": false,
        "per_process_gpu_memory_fraction": 0.5,
        "visible_device_list": "1"
    }
}
其中
per_process_gpu_memory_fraction
表示使用显存的比例;

visible_device_list
表示使用第几块显卡,0开始;

然后就需要修改后端代码:

在python库路径中找到keras位置,Linux下一般也就是:/usr/local/lib/python2.7/dist-packages/keras-2.****

这样一个路径,修改TensorFlow后端代码:

def get_session():
    """Returns the TF session to be used by the backend.

    If a default TensorFlow session is available, we will return it.

    Else, we will return the global Keras session.

    If no global Keras session exists at this point:
    we will create a new global session.

    Note that you can manually set the global session
    via `K.set_session(sess)`.

    # Returns
        A TensorFlow session.
    """
    global _SESSION
    if tf.get_default_session() is not None:
        session = tf.get_default_session()
    else:
        if _SESSION is None:
            _keras_base_dir = os.path.expanduser('~')
            if not os.access(_keras_base_dir, os.W_OK):
                 _keras_base_dir = '/tmp'

            _keras_dir = os.path.join(_keras_base_dir, '.keras')
            _config_path = os.path.expanduser(os.path.join(_keras_dir,
                                                            'keras.json'))

            if os.path.exists(_config_path):
                try:
                    _config = json.load(open(_config_path))
                except ValueError:
                    _config = {}
            _options = _config.get('gpu_options', None)
            _allow_growth = _options.get('allow_growth', False)
            _mem_frac = _options.get('per_process_gpu_memory_fraction', 1.0)
            _visible_device_list = _options.get('visible_device_list', None)
            _gpu_options = tf.GPUOptions(allow_growth=_allow_growth,
                                         per_process_gpu_memory_fraction=_mem_frac,
                                         visible_device_list=_visible_device_list)

            if not os.environ.get('OMP_NUM_THREADS'):
                config = tf.ConfigProto(allow_soft_placement=True,
                                        gpu_options=_gpu_options)
            else:
                num_thread = int(os.environ.get('OMP_NUM_THREADS'))
                config = tf.ConfigProto(intra_op_parallelism_threads=num_thread,
                                        allow_soft_placement=True,
                                        gpu_options=_gpu_options)
            _SESSION = tf.Session(config=config)
        session = _SESSION
    if not _MANUAL_VAR_INIT:
        with session.graph.as_default():
            _initialize_variables()
    return session

需要额外import 一个json库用来解析keras.json配置文件。

其实也就是对TensorFlow的session属性进行一下初始化。



你可能感兴趣的:(Keras)