【陪你聊TensorLayer-3】Implementing ResNet in TensorLayer

As is well known, ResNet is the network architecture that won the 2015 ImageNet competition by an overwhelming margin. Its building block looks like this:

[Figure 1: structure of the ResNet residual block]

Its main idea is to pass a shallow layer's output directly through to deeper layers, which eases the vanishing-gradient problem and makes it possible to train deeper networks for better performance. In equation form, each residual block computes y = F(x) + x, where F(x) is the transformation learned by the block's stacked layers and x is the block input carried over unchanged by the shortcut. So how do we implement this in TensorLayer?

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
w_init = tf.random_normal_initializer(stddev=0.02)
b_init = None  # tf.constant_initializer(value=0.0)
g_init = tf.random_normal_initializer(1., 0.02)
# NOTE: t_image (the input image tensor) and is_train (a Python bool that switches
# batch norm between training and inference) are assumed to be defined outside this snippet.
with tf.variable_scope("SRGAN_g"):
    n = InputLayer(t_image, name='in')
    n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
    temp = n

    # B residual blocks
    for i in range(16):
        nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
        nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
        nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
        nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
        # residual (skip) connection: element-wise add of the block input n and the block output nn
        nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
        n = nn

    n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
    n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
    # long skip connection: add back the feature map saved in `temp` before the residual blocks
    n = ElementwiseLayer([n, temp], tf.add, name='add3')
    # B residual blocks end

    # upsample by 2x with a sub-pixel convolution (pixel shuffle)
    n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
    n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

    # upsample by another 2x, then map the features to a 3-channel image with tanh activation
    n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
    n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

    n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
The code above is a snippet taken from the TensorLayer source code (the variable scope name SRGAN_g suggests it comes from the SRGAN generator example) and shows how ResNet-style residual blocks are built in TensorLayer. The author applies BN (batch normalization) inside the residual blocks; BN was introduced in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift". A paper by Zhou Bolei argues that this technique seems to hurt a model's overall expressive power somewhat, yet it remains very effective in practice, so we won't go into the details here.

Back to the ResNet part of the model: the for loop is its main body and stacks 16 residual blocks. Each residual block consists of two convolutional layers (each followed by batch normalization), and the residual connection is implemented by passing tf.add to ElementwiseLayer, which adds the block's input to its output. If you are interested, try swapping this kind of network into the MNIST example given in 【陪你聊TensorLayer-1】 and see how it performs; a rough sketch of how the residual block could be wrapped for such an experiment follows.
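Below is a minimal sketch (my own illustration, not code from the original post) of that idea: the residual block is pulled out into a helper function and stacked a few times on MNIST-sized input. The name residual_block, the filter counts, and the toy placeholder setup are assumptions for illustration; the layer calls themselves use the same TensorLayer 1.x API as the snippet above.

import tensorflow as tf
from tensorlayer.layers import InputLayer, Conv2d, BatchNormLayer, ElementwiseLayer

def residual_block(n, n_filter, is_train, name):
    # conv -> BN(ReLU) -> conv -> BN, then add the block input back in
    nn = Conv2d(n, n_filter, (3, 3), (1, 1), act=None, padding='SAME', name=name + '/c1')
    nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, name=name + '/b1')
    nn = Conv2d(nn, n_filter, (3, 3), (1, 1), act=None, padding='SAME', name=name + '/c2')
    nn = BatchNormLayer(nn, is_train=is_train, name=name + '/b2')
    # skip connection: element-wise add of the block input and the block output
    return ElementwiseLayer([n, nn], tf.add, name=name + '/add')

# toy usage on MNIST-shaped images (28x28x1)
x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
net = InputLayer(x, name='in')
net = Conv2d(net, 16, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', name='stem')
for i in range(3):
    net = residual_block(net, 16, is_train=True, name='res%d' % i)

Alright, Bye~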
