slim.arg_scope原理分析

slim.arg_scope的实现使用了修饰器和上下文管理器. 弄清楚这两个语法才能看明白它的源码.


使用方式

slim.arg_scope常用于为tensorflow里的layer函数提供默认值以使构建模型的代码更加紧凑苗条(slim):

import tensorflow as tf
slim = tf.contrib.slim
with slim.arg_scope(
        [layers.conv2d], 
        padding='SAME',         
        initializer= xavier_initializer(),
        regularizer= l2_regularizer(0.05)):  
    net = slim.conv2d(net, 256, [5, 5])

并不是所有的方法都能用arg_scope设置默认参数, 只有用@slim.add_arg_scope修饰过的方法才能使用arg_scope. 例如示例中的conv2d方法, 它就是被修饰过的(见源码).
所以, 要使slim.arg_scope正常运行起来, 需要两个步骤:

  1. @add_arg_scope修饰目标函数
  2. with arg_scope(...) 设置默认参数.

下面通过它的源码来分析一下它具体是如何工作的.

关键数据结构

代码:

_ARGSTACK = [{}]
_DECORATED_OPS = {}

_ARGSTACK被当作一个栈来使用, 用于存储所有的scope. 一个scope就是一个dict. 它的key是str(fn), 其中fn代表目标函数;value是一个字典, 存储着用户调用slim.arg_scope时传入的参数. 之所以用栈, 是因为with是可以嵌套使用的. 栈顶存储的是最内层/最近的scope, 即当前scope.
例如,

@slim.add_arg_scope
def fn(a, b, c=3):
    d = c + b
    print("a={}, b={}".format(a, b))
    return d

with slim.arg_scope([fn], a = 1):
    fn(b = 2)

在上面的这段代码中, 执行到fn(b=2)时, _ARGSTACK的值为:

[{}, {'': {'a': 1}}]

_DECORATED_OPS是一个dict, 用于存储已经被@add_arg_scope修饰过的函数信息. 它的key也是str(fn), value的生成机制却很难get到它的真实意图, 给我的感觉一是有错, 二是无用. 所以, 把_DECORATED_OPS换成一个set应该也是可以的.

仍然是上面那段代码, _DECORATED_OPS的值为:

{'': ()}

add_arg_scope

只保留了关键的代码:

def add_arg_scope(func):
  """Decorates a function with args so it can be used within an arg_scope.
  ...
  """
  @functools.wraps(func)
  def func_with_args(*args, **kwargs):
      ...
  _add_op(func) # 在这里将函数放入`_DECORATED_OPS`里.
  ...
  return func_with_args

_add_op的操作真的很简单, 看看_DECORATED_OPS里有没有目标函数, 没有就放进去:

def _add_op(op):
  key_op = _key_op(op)
  if key_op not in _DECORATED_OPS:
    _DECORATED_OPS[key_op] = _kwarg_names(op)
def _key_op(op):
  return getattr(op, '_key_op', str(op))

func_with_args内容暂时略过, 后面会分析.

with arg_scope([], ...)

也是只保留了最关键的代码:

@contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  """Stores the default arguments for the given set of list_ops.
    ...
  """
  ...
    try:
      current_scope = _current_arg_scope().copy()
      ...
        if key_op in current_scope:
          # 如果对一个函数多次应用arg_scope, 则更新已经存储的参数.
          current_kwargs = current_scope[key_op].copy()
          current_kwargs.update(kwargs)
          current_scope[key_op] = current_kwargs
        else:
          # 如果是第一次对这个函数调用arg_scope, 则加进去.
          current_scope[key_op] = kwargs.copy()
      # 将当前scope放入栈顶.
      _get_arg_stack().append(current_scope)
      yield current_scope
    finally:
      _get_arg_stack().pop()

def _get_arg_stack():
  if _ARGSTACK:
    return _ARGSTACK
  else:
    _ARGSTACK.append({})
    return _ARGSTACK

with代码块内部调用目标函数

之前省略的函数func_with_args其实是关键中的关键: 在with代码块内调用被@add_arg_scope修饰的函数, 其实是调用了它, 而它负责从当前栈中查找用户设置的默认参数, 加上用户传入的调用参数, 调用真正的源函数.
func_with_args的代码:

  def func_with_args(*args, **kwargs):
    current_scope = _current_arg_scope()# 当前栈
    current_args = kwargs # 用户传入的关键字参数
    key_func = _key_op(func)  
    if key_func in current_scope:
      # 通过arg_scope设置的默认参数
      current_args = current_scope[key_func].copy() 
      current_args.update(kwargs)
    # 调用原函数
    return func(*args, **current_args)

退出with代码块

弹出栈顶的scope. 代码在arg_scope方法最后一行.

你可能感兴趣的:(tensorflow)