Tensorflow-slim作为一种轻量级的tensorflow库,使得模型的构建,训练,测试都变得更加简单,其主要目的是来做所谓的“代码瘦身”。
在slim库中对很多常用的函数进行了定义,slim.arg_scope()是slim库中经常用到的函数之一。
slim.arg_scope常用于为tensorflow里的layer函数提供默认值,以使构建模型的代码更加紧凑苗条(slim)。
函数的定义如下:
def arg_scope(list_ops_or_scope, **kwargs)
list_ops_or_scope:要用的函数的作用域,可以在需要使用的地方用@add_arg_scope 声明
**kwargs: keyword=value 定义了list_ops中要使用的变量
如注释中所说,这个函数的作用是给list_ops中的内容设置默认值。但是每个list_ops中的每个成员需要用@add_arg_scope修饰才行。也就是说可以通过这个函数将不想重复写的参数通过这个函数自动赋值。
所以使用slim.arg_scope()有两个步骤:
- 使用@slim.add_arg_scope修饰目标函数
- 用 slim.arg_scope()为目标函数设置默认参数.
例如如下代码;首先用@slim.add_arg_scope修饰目标函数fun1(),然后利用slim.arg_scope()为它设置默认参数。
import tensorflow as tf
slim =tf.contrib.slim
@slim.add_arg_scope
def fun1(a=0,b=0):
return (a+b)
with slim.arg_scope([fun1],a=10):
x=fun1(b=30)
print(x)
运行结果为:
40
例2:
import tensorflow.contrib.slim as slim
@slim.add_arg_scope
def g(name, add_arg):
print("name:", name)
print("add_arg:", add_arg)
with slim.arg_scope([g], add_arg='this is add'):
g('test')
#结果:
#name: test
#add_arg: this is add
例3:
平常所用到的slim.conv2d( ),slim.fully_connected( ),slim.max_pool2d( )等函数在他被定义的时候就已经添加了@add_arg_scope。以slim.conv2d( )为例;
@add_arg_scope
def convolution(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
data_format=None,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):
所以,在使用过程中可以直接slim.conv2d( )等函数设置默认参数。例如在下面的代码中,不做单独声明的情况下,slim.conv2d, slim.max_pool2d, slim.avg_pool2d三个函数默认的步长都设为1,padding模式都是'VALID'的。但是也可以在调用时进行单独声明。这种参数设置方式在构建网络模型时,尤其是较深的网络时,可以节省时间。
with slim.arg_scope(
[slim.conv2d, slim.max_pool2d, slim.avg_pool2d],stride = 1, padding = 'VALID'):
net = slim.conv2d(inputs, 32, [3, 3], stride = 2, scope = 'Conv2d_1a_3x3')
net = slim.conv2d(net, 32, [3, 3], scope = 'Conv2d_2a_3x3')
net = slim.conv2d(net, 64, [3, 3], padding = 'SAME', scope = 'Conv2d_2b_3x3')
其实这种用法是python中常用到的。在python中@修饰符放在函数定义的上方,它将被修饰的函数作为参数,并返回修饰后的同名函数。形式如下;
@fun_a #等价于fun_a(fun_b)
def fun_b():
这在本质上讲跟直接调用被修饰的函数没什么区别,但是有时候也有用处,例如在调用被修饰函数前需要输出时间信息,我们可以在@后方的函数中添加输出时间信息的语句,这样每次我们只需要调用@后方的函数即可。
import tensorflow as tf
slim =tf.contrib.slim
#用@slim.add_arg_scope修饰目标函数fun1()
@slim.add_arg_scope
def fun1(a=0,b=0):
return (a+b)
#用slim.arg_scope()为目标函数fun1()设置默认参数
with slim.arg_scope([fun1],a=20):
x=fun1(b=20)
print(x)
# 结果:40