Inside TF-Slim(2) arg_scope

0. 前言

  • arg_scope 源码地址
  • 参考博客:slim.arg_scope原理分析。
  • 源码使用到了Python特性中的 装饰器(即类似Java注释) 和 上下文管理(即with语句)
  • 其他:
    • 不能通过指定model_variable,来改变slim.conv2d中参数的trainable属性。不要问我怎么知道的。

 

 

1. 基本功能与使用

1.1. 基本功能

  • arg_scope的主要功能是为一些操作提供默认参数。
  • arg_scope可以叠加使用。
  • API定义:
def arg_scope(list_ops_or_scope, **kwargs):
# list_ops_or_scope: 需要添加默认参数的ops和scope
# **kwargs: 默认参数列表,如 param1='abcd' 

1.2. 官方实例

  • 官方实例1,如何使用arg_scope
from third_party.tensorflow.contrib.layers.python import layers
  arg_scope = tf.contrib.framework.arg_scope
  with arg_scope([layers.conv2d], padding='SAME',
                 initializer=layers.variance_scaling_initializer(),
                 regularizer=layers.l2_regularizer(0.05)):
    net = layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
    net = layers.conv2d(net, 256, [5, 5], scope='conv2')
    
# 其中,第一个conv2d相当于:
layers.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
              initializer=layers.variance_scaling_initializer(),
              regularizer=layers.l2_regularizer(0.05), scope='conv1')
              
# 第二个conv2d相当于:
layers.conv2d(inputs, 256, [5, 5], padding='SAME',
              initializer=layers.variance_scaling_initializer(),
              regularizer=layers.l2_regularizer(0.05), scope='conv2')
  • 官方实例2,如何复用arg_scope
with arg_scope([layers.conv2d], padding='SAME',
                 initializer=layers.variance_scaling_initializer(),
                 regularizer=layers.l2_regularizer(0.05)) as sc:
    net = layers.conv2d(net, 256, [5, 5], scope='conv1')
    ....

# 可以直接通过arg_scope对象
with arg_scope(sc):
    net = layers.conv2d(net, 256, [5, 5], scope='conv2') 
  • 官方实例3,如何令自己创建的函数也可用于arg_scope
@tf.contrib.framework.add_arg_scope
def conv2d(*args, **kwargs)

1.3. 使用总结

  • 使用步骤:
  1. 使用@slim.add_arg_scope修饰目标操作(即函数)。
  2. 使用with slim.arg_scope(...):设置默认参数。
  • 自己的一些测试:
import tensorflow.contrib.slim as slim

# 定义一些自己的函数,并使用add_arg_scope装饰
@slim.add_arg_scope
def my_function1(param1='param1', param2='param2'):
    print('my_function1', param1, param2)
@slim.add_arg_scope
def my_function2(param2='param2', param3='param3'):
    print('my_function2', param2, param3)
  
# 操作一  
# 普通的调用arg_scope的方式
with slim.arg_scope([my_function1], param2='param2_modify'):
    my_function1()
    # 输出 my_function1 param1 param2_modify

# 操作二
# 如果设置的一些参数,不存在于ops列表中,则会报错
# 以下实例中,my_function1函数中没有参数param3,所以会报错
with slim.arg_scope([my_function1], param3='param2_modify'):
    my_function1()

# 操作三
# 优先级
# 1. 优先级最高的是函数本身的实参
# 2. 有多层arg_scope时,优先级最高的时最内侧的arg_scope
# 3. 优先级最低的是函数定义中的默认参数
with slim.arg_scope([my_function1], param2='param2_modify1'):
    with slim.arg_scope([my_function1], param2='param2_modify2'):
        my_function1(param2=2)
        # 输出为 my_function1 param1 2

# 操作四
# 当命名参数与非命名参数同时使用时要注意
# 如函数 my_function1中的参数列表为param1, param2
# 则以下函数会报错
with slim.arg_scope([my_function1], param1=1):
    my_function1(1, 2)
    # 错误 my_function1() got multiple values for argument 'param1'
    # 错误分析:从源码来看,非命名参数与命名参数是分开处理的,所以以上代码等价于
    # my_function1(1, 2, param1=1)

# 操作五
# 对于以arg_scope作为参数传递时
with slim.arg_scope([my_function1], param2=22) as s:
    pass

with slim.arg_scope([my_function1], param1=111):
    with slim.arg_scope(s):
        my_function1()
        # 输出 myfunction1 param1 22
        # 由此可见,最外层的arg_scope完全不起作用

with slim.arg_scope([my_function2], param3=123):
    with slim.arg_scope(s):
        my_function2()
        # 输出 my_function2() param2 param3
        # 由此可见,最外层arg_scope完全没有效果
        # 即使s对象完全没有对my_function2进行操作,最外层arg_scope也没有起作用

 

2.源码理解

2.1. arg_scope数据存储介绍

# 列表,实现"栈"结构,用于多层arg_scope
# 列表中每个元素为字典,代表一层arg_scope,且融合了之前所有arg_scope的内容
# (有特例,就是使用arg_scope对象、即字典对象,作为参数传递到arg_scope中,具体参考arg_scope函数源码)
# 元素key为str(func)
# 元素value为字典:即 参数名(字符串) -> 默认参数值
_ARGSTACK = [{}]

# 字典,用于存储所有被@add_arg_scope修饰的函数
# key为函数名称,即str(func)
# value为函数命名参数列表
_DECORATED_OPS = {}

2.2. 函数功能介绍

  • 私有函数
# 获取 _ARGSTACK 对象
def _get_arg_stack():
  if _ARGSTACK:
    return _ARGSTACK
  else:
    _ARGSTACK.append({})
    return _ARGSTACK

# 获取函数的属性'_key_op'的值,该属性不存在,则返回str(op)
def _key_op(op):
  return getattr(op, '_key_op', str(op))


# 获取当前函数的模块名称与函数名称
def _name_op(op):
  return (op.__module__, op.__name__)

# 获取当前函数所有命名参数(有默认数值属性)的列表
def _kwarg_names(func):
  kwargs_length = len(func.__defaults__) if func.__defaults__ else 0
  return func.__code__.co_varnames[-kwargs_length:func.__code__.co_argcount]

# 将当前函数添加到 _DECORATED_OPS 中
# 从这儿可以看出,_DECORATED_OPS的key为函数名,value为有默认属性值命名参数的列表
def _add_op(op):
  key_op = _key_op(op)
  if key_op not in _DECORATED_OPS:
    _DECORATED_OPS[key_op] = _kwarg_names(op)
  • 共有函数
# 获取当前arg_scope
# arg_scope可以叠加使用,该方法获取最内侧的arg_scope对象
# 换句话说,获取栈顶部的arg_scope字典对象
def current_arg_scope():
  stack = _get_arg_stack()
  return stack[-1]

# 最重要的函数,作用为:为选定的操作,添加默认参数
# 根据定义可以看出,该函数使用了上下文管理器,简单说就是为了使用 with 而必须的操作。
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  if isinstance(list_ops_or_scope, dict):
    # 当输入的list_ops_or_scope为其他arg_scope对象时,就会输入字典对象
    if kwargs:
      # 当list_ops_or_scope为字典对象时,kwargs必须为空,否则报错
      raise ValueError('When attempting to re-use a scope by suppling a'
                       'dictionary, kwargs must be empty.')
    current_scope = list_ops_or_scope.copy()
    try:
    
      # 从以下代码可以看出,当使用arg_scope对象作为参数时,之前所有arg_scope对象都不起作用了
      # 即1.5.节中的操作五
      _get_arg_stack().append(current_scope)
      yield current_scope
    finally:
      _get_arg_stack().pop()
      
      
  else:
    # 当输入 list_ops_or_scope 为list或tuple时
    if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
                      'scope (i.e. dict)')
    try:
      # 第一步:复制前一层arg_scope中所有操作的默认参数
      current_scope = current_arg_scope().copy()
      
      # 第二步:遍历所有op(即函数),判断该函数是否被@add_arg_scope修饰,
      #         最后将当前arg_scope中的参数列表与之前arg_scope中的参数列表合并
      for op in list_ops_or_scope:
        key_op = _key_op(op)
        
        # 判断是否被@add_arg_scope修饰
        if not has_arg_scope(op):
          raise ValueError('%s is not decorated with @add_arg_scope',
                           _name_op(op))
        
        # 合并参数列表
        if key_op in current_scope:
          current_kwargs = current_scope[key_op].copy()
          current_kwargs.update(kwargs)
          current_scope[key_op] = current_kwargs
        else:
          current_scope[key_op] = kwargs.copy()
      
      # 第三步:将当前arg_scope对象放置到栈中
      _get_arg_stack().append(current_scope)
      
      # 上下文管理器标志位,之前都是__enter__操作,之后都是__exit__操作
      yield current_scope
    finally:
      # 离开with块时,将当前arg_scope删除
      _get_arg_stack().pop()



# 该函数设计到Python中的 修饰器
# 为了使得arg_socpe有效,相关操作(函数)必须使用@add_arg_scope修饰
# 主要作用:将函数添加到内部数据结构中,使得可以通过arg_scope操作
def add_arg_scope(func):
  # 用以下函数修饰输入函数func
  # 从该函数的输入列表中可以看出,命名参数与非命名参数是分开处理的
  # 且合并的参数仅仅为命名参数(不对非命名参数进行处理)
  # 即1.5.节中的操作四
  def func_with_args(*args, **kwargs):
    current_scope = current_arg_scope()
    current_args = kwargs
    key_func = _key_op(func)
    if key_func in current_scope:
      current_args = current_scope[key_func].copy()
      current_args.update(kwargs)
    return func(*args, **current_args)

  # 被修饰的函数,都要被添加到_DECORATED_OPS中
  _add_op(func)
  setattr(func_with_args, '_key_op', _key_op(func))
  return tf_decorator.make_decorator(func, func_with_args)

# 查看一个函数是否被add_arg_scope修饰
def has_arg_scope(func):
  return _key_op(func) in _DECORATED_OPS

# 通过arc_scope可为该函数设置哪些变量的默认参数
# 起始该方法不准确,因为也可以通过arg_scope设置一些非命名参数的默认值,但容易出错……
def arg_scoped_arguments(func):
  assert has_arg_scope(func)
  return _DECORATED_OPS[_key_op(func)]

 原文:https://zhuanlan.zhihu.com/p/33848199

你可能感兴趣的:(Tensorflow)