model_variable
,来改变slim.conv2d
中参数的trainable
属性。不要问我怎么知道的。
arg_scope
的主要功能是为一些操作提供默认参数。arg_scope
可以叠加使用。def arg_scope(list_ops_or_scope, **kwargs):
# list_ops_or_scope: 需要添加默认参数的ops和scope
# **kwargs: 默认参数列表,如 param1='abcd'
arg_scope
from third_party.tensorflow.contrib.layers.python import layers
arg_scope = tf.contrib.framework.arg_scope
with arg_scope([layers.conv2d], padding='SAME',
initializer=layers.variance_scaling_initializer(),
regularizer=layers.l2_regularizer(0.05)):
net = layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
net = layers.conv2d(net, 256, [5, 5], scope='conv2')
# 其中,第一个conv2d相当于:
layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
initializer=layers.variance_scaling_initializer(),
regularizer=layers.l2_regularizer(0.05), scope='conv1')
# 第二个conv2d相当于:
layers.conv2d(inputs, 256, [5, 5], padding='SAME',
initializer=layers.variance_scaling_initializer(),
regularizer=layers.l2_regularizer(0.05), scope='conv2')
arg_scope
with arg_scope([layers.conv2d], padding='SAME',
initializer=layers.variance_scaling_initializer(),
regularizer=layers.l2_regularizer(0.05)) as sc:
net = layers.conv2d(net, 256, [5, 5], scope='conv1')
....
# 可以直接通过arg_scope对象
with arg_scope(sc):
net = layers.conv2d(net, 256, [5, 5], scope='conv2')
arg_scope
中@tf.contrib.framework.add_arg_scope
def conv2d(*args, **kwargs)
@slim.add_arg_scope
修饰目标操作(即函数)。with slim.arg_scope(...):
设置默认参数。import tensorflow.contrib.slim as slim
# 定义一些自己的函数,并使用add_arg_scope装饰
@slim.add_arg_scope
def my_function1(param1='param1', param2='param2'):
print('my_function1', param1, param2)
@slim.add_arg_scope
def my_function2(param2='param2', param3='param3'):
print('my_function2', param2, param3)
# 操作一
# 普通的调用arg_scope的方式
with slim.arg_scope([my_function1], param2='param2_modify'):
my_function1()
# 输出 my_function1 param1 param2_modify
# 操作二
# 如果设置的一些参数,不存在于ops列表中,则会报错
# 以下实例中,my_function1函数中没有参数param3,所以会报错
with slim.arg_scope([my_function1], param3='param2_modify'):
my_function1()
# 操作三
# 优先级
# 1. 优先级最高的是函数本身的实参
# 2. 有多层arg_scope时,优先级最高的时最内侧的arg_scope
# 3. 优先级最低的是函数定义中的默认参数
with slim.arg_scope([my_function1], param2='param2_modify1'):
with slim.arg_scope([my_function1], param2='param2_modify2'):
my_function1(param2=2)
# 输出为 my_function1 param1 2
# 操作四
# 当命名参数与非命名参数同时使用时要注意
# 如函数 my_function1中的参数列表为param1, param2
# 则以下函数会报错
with slim.arg_scope([my_function1], param1=1):
my_function1(1, 2)
# 错误 my_function1() got multiple values for argument 'param1'
# 错误分析:从源码来看,非命名参数与命名参数是分开处理的,所以以上代码等价于
# my_function1(1, 2, param1=1)
# 操作五
# 对于以arg_scope作为参数传递时
with slim.arg_scope([my_function1], param2=22) as s:
pass
with slim.arg_scope([my_function1], param1=111):
with slim.arg_scope(s):
my_function1()
# 输出 myfunction1 param1 22
# 由此可见,最外层的arg_scope完全不起作用
with slim.arg_scope([my_function2], param3=123):
with slim.arg_scope(s):
my_function2()
# 输出 my_function2() param2 param3
# 由此可见,最外层arg_scope完全没有效果
# 即使s对象完全没有对my_function2进行操作,最外层arg_scope也没有起作用
# 列表,实现"栈"结构,用于多层arg_scope
# 列表中每个元素为字典,代表一层arg_scope,且融合了之前所有arg_scope的内容
# (有特例,就是使用arg_scope对象、即字典对象,作为参数传递到arg_scope中,具体参考arg_scope函数源码)
# 元素key为str(func)
# 元素value为字典:即 参数名(字符串) -> 默认参数值
_ARGSTACK = [{}]
# 字典,用于存储所有被@add_arg_scope修饰的函数
# key为函数名称,即str(func)
# value为函数命名参数列表
_DECORATED_OPS = {}
# 获取 _ARGSTACK 对象
def _get_arg_stack():
if _ARGSTACK:
return _ARGSTACK
else:
_ARGSTACK.append({})
return _ARGSTACK
# 获取函数的属性'_key_op'的值,该属性不存在,则返回str(op)
def _key_op(op):
return getattr(op, '_key_op', str(op))
# 获取当前函数的模块名称与函数名称
def _name_op(op):
return (op.__module__, op.__name__)
# 获取当前函数所有命名参数(有默认数值属性)的列表
def _kwarg_names(func):
kwargs_length = len(func.__defaults__) if func.__defaults__ else 0
return func.__code__.co_varnames[-kwargs_length:func.__code__.co_argcount]
# 将当前函数添加到 _DECORATED_OPS 中
# 从这儿可以看出,_DECORATED_OPS的key为函数名,value为有默认属性值命名参数的列表
def _add_op(op):
key_op = _key_op(op)
if key_op not in _DECORATED_OPS:
_DECORATED_OPS[key_op] = _kwarg_names(op)
# 获取当前arg_scope
# arg_scope可以叠加使用,该方法获取最内侧的arg_scope对象
# 换句话说,获取栈顶部的arg_scope字典对象
def current_arg_scope():
stack = _get_arg_stack()
return stack[-1]
# 最重要的函数,作用为:为选定的操作,添加默认参数
# 根据定义可以看出,该函数使用了上下文管理器,简单说就是为了使用 with 而必须的操作。
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
if isinstance(list_ops_or_scope, dict):
# 当输入的list_ops_or_scope为其他arg_scope对象时,就会输入字典对象
if kwargs:
# 当list_ops_or_scope为字典对象时,kwargs必须为空,否则报错
raise ValueError('When attempting to re-use a scope by suppling a'
'dictionary, kwargs must be empty.')
current_scope = list_ops_or_scope.copy()
try:
# 从以下代码可以看出,当使用arg_scope对象作为参数时,之前所有arg_scope对象都不起作用了
# 即1.5.节中的操作五
_get_arg_stack().append(current_scope)
yield current_scope
finally:
_get_arg_stack().pop()
else:
# 当输入 list_ops_or_scope 为list或tuple时
if not isinstance(list_ops_or_scope, (list, tuple)):
raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
'scope (i.e. dict)')
try:
# 第一步:复制前一层arg_scope中所有操作的默认参数
current_scope = current_arg_scope().copy()
# 第二步:遍历所有op(即函数),判断该函数是否被@add_arg_scope修饰,
# 最后将当前arg_scope中的参数列表与之前arg_scope中的参数列表合并
for op in list_ops_or_scope:
key_op = _key_op(op)
# 判断是否被@add_arg_scope修饰
if not has_arg_scope(op):
raise ValueError('%s is not decorated with @add_arg_scope',
_name_op(op))
# 合并参数列表
if key_op in current_scope:
current_kwargs = current_scope[key_op].copy()
current_kwargs.update(kwargs)
current_scope[key_op] = current_kwargs
else:
current_scope[key_op] = kwargs.copy()
# 第三步:将当前arg_scope对象放置到栈中
_get_arg_stack().append(current_scope)
# 上下文管理器标志位,之前都是__enter__操作,之后都是__exit__操作
yield current_scope
finally:
# 离开with块时,将当前arg_scope删除
_get_arg_stack().pop()
# 该函数设计到Python中的 修饰器
# 为了使得arg_socpe有效,相关操作(函数)必须使用@add_arg_scope修饰
# 主要作用:将函数添加到内部数据结构中,使得可以通过arg_scope操作
def add_arg_scope(func):
# 用以下函数修饰输入函数func
# 从该函数的输入列表中可以看出,命名参数与非命名参数是分开处理的
# 且合并的参数仅仅为命名参数(不对非命名参数进行处理)
# 即1.5.节中的操作四
def func_with_args(*args, **kwargs):
current_scope = current_arg_scope()
current_args = kwargs
key_func = _key_op(func)
if key_func in current_scope:
current_args = current_scope[key_func].copy()
current_args.update(kwargs)
return func(*args, **current_args)
# 被修饰的函数,都要被添加到_DECORATED_OPS中
_add_op(func)
setattr(func_with_args, '_key_op', _key_op(func))
return tf_decorator.make_decorator(func, func_with_args)
# 查看一个函数是否被add_arg_scope修饰
def has_arg_scope(func):
return _key_op(func) in _DECORATED_OPS
# 通过arc_scope可为该函数设置哪些变量的默认参数
# 起始该方法不准确,因为也可以通过arg_scope设置一些非命名参数的默认值,但容易出错……
def arg_scoped_arguments(func):
assert has_arg_scope(func)
return _DECORATED_OPS[_key_op(func)]
原文:https://zhuanlan.zhihu.com/p/33848199