slim.arg_scope
的实现使用了修饰器和上下文管理器. 弄清楚这两个语法才能看明白它的源码.
slim.arg_scope
常用于为tensorflow里的layer函数提供默认值以使构建模型的代码更加紧凑苗条(slim):
import tensorflow as tf
slim = tf.contrib.slim
with slim.arg_scope(
[layers.conv2d],
padding='SAME',
initializer= xavier_initializer(),
regularizer= l2_regularizer(0.05)):
net = slim.conv2d(net, 256, [5, 5])
并不是所有的方法都能用arg_scope
设置默认参数, 只有用@slim.add_arg_scope
修饰过的方法才能使用arg_scope
. 例如示例中的conv2d
方法, 它就是被修饰过的(见源码).
所以, 要使slim.arg_scope
正常运行起来, 需要两个步骤:
@add_arg_scope
修饰目标函数with arg_scope(...)
设置默认参数.下面通过它的源码来分析一下它具体是如何工作的.
代码:
_ARGSTACK = [{}]
_DECORATED_OPS = {}
_ARGSTACK
被当作一个栈来使用, 用于存储所有的scope. 一个scope就是一个dict. 它的key是str(fn)
, 其中fn代表目标函数;value是一个字典, 存储着用户调用slim.arg_scope
时传入的参数. 之所以用栈, 是因为with是可以嵌套使用的. 栈顶存储的是最内层/最近的scope, 即当前scope.
例如,
@slim.add_arg_scope
def fn(a, b, c=3):
d = c + b
print("a={}, b={}".format(a, b))
return d
with slim.arg_scope([fn], a = 1):
fn(b = 2)
在上面的这段代码中, 执行到fn(b=2)
时, _ARGSTACK
的值为:
[{}, {'' : {'a': 1}}]
_DECORATED_OPS
是一个dict, 用于存储已经被@add_arg_scope
修饰过的函数信息. 它的key也是str(fn)
, value的生成机制却很难get到它的真实意图, 给我的感觉一是有错, 二是无用. 所以, 把_DECORATED_OPS
换成一个set
应该也是可以的.
仍然是上面那段代码, _DECORATED_OPS
的值为:
{'' : ()}
add_arg_scope
只保留了关键的代码:
def add_arg_scope(func):
"""Decorates a function with args so it can be used within an arg_scope.
...
"""
@functools.wraps(func)
def func_with_args(*args, **kwargs):
...
_add_op(func) # 在这里将函数放入`_DECORATED_OPS`里.
...
return func_with_args
_add_op
的操作真的很简单, 看看_DECORATED_OPS
里有没有目标函数, 没有就放进去:
def _add_op(op):
key_op = _key_op(op)
if key_op not in _DECORATED_OPS:
_DECORATED_OPS[key_op] = _kwarg_names(op)
def _key_op(op):
return getattr(op, '_key_op', str(op))
func_with_args
内容暂时略过, 后面会分析.
with arg_scope([], ...)
也是只保留了最关键的代码:
@contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
"""Stores the default arguments for the given set of list_ops.
...
"""
...
try:
current_scope = _current_arg_scope().copy()
...
if key_op in current_scope:
# 如果对一个函数多次应用arg_scope, 则更新已经存储的参数.
current_kwargs = current_scope[key_op].copy()
current_kwargs.update(kwargs)
current_scope[key_op] = current_kwargs
else:
# 如果是第一次对这个函数调用arg_scope, 则加进去.
current_scope[key_op] = kwargs.copy()
# 将当前scope放入栈顶.
_get_arg_stack().append(current_scope)
yield current_scope
finally:
_get_arg_stack().pop()
def _get_arg_stack():
if _ARGSTACK:
return _ARGSTACK
else:
_ARGSTACK.append({})
return _ARGSTACK
with
代码块内部调用目标函数之前省略的函数func_with_args
其实是关键中的关键: 在with
代码块内调用被@add_arg_scope
修饰的函数, 其实是调用了它, 而它负责从当前栈中查找用户设置的默认参数, 加上用户传入的调用参数, 调用真正的源函数.
func_with_args
的代码:
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()# 当前栈
current_args = kwargs # 用户传入的关键字参数
key_func = _key_op(func)
if key_func in current_scope:
# 通过arg_scope设置的默认参数
current_args = current_scope[key_func].copy()
current_args.update(kwargs)
# 调用原函数
return func(*args, **current_args)
with
代码块弹出栈顶的scope
. 代码在arg_scope
方法最后一行.