本文探讨了一种新的深度学习算法——深度残差收缩网络(Deep Residual Shrinkage Network),加入了笔者自己的理解。
1.深度残差收缩网络的相关基础
从名字就能够看出,深度残差收缩网络是深度残差网络的一种改进方法。其特色是“收缩”,在这里指的是软阈值化,而软阈值化几乎是现在信号降噪算法的必备步骤。
因此,深度残差收缩网络是一种面向强噪数据的深度学习算法,是信号处理里的经典内容和深度学习、注意力机制的又一种结合。
深度残差收缩网络的基本模块如下图(a)所示,通过一个小型子网络,学习得到一组阈值,然后进行特征的软阈值化。同时,该模块还加入了恒等路径,以降低模型训练难度。深度残差收缩网络的整体结构如下图(b)所示,与通常的深度残差网络是一样的。
那么为何要进行收缩呢?收缩有什么好处呢?本文尝试从删除冗余特征的灵活度的角度,展开了讨论。
2.收缩(这里指软阈值化)
不了解软阈值化的同学可以去搜一下Soft Threshlding,在谷歌学术上会搜到这一篇:DL Donoho. De-noising by soft-thresholding[J]. IEEE transactions on information theory, 1995.
De-noising by soft-thresholding这一篇论文,目前的引用次数是12893次。可以看出来,软阈值化是一种经典的方法,尤其在信号降噪方面是非常常用的。
软阈值函数的式子如下:
其中t是阈值,是一个正数。从公式可以看出,软阈值化将[-t,t]区间内的特征置为0,将大于t的特征减t,将小于-t的特征加t。
如果用图片表示软阈值函数,就如下图所示:
3.收缩(这里指软阈值化)与ReLU激活函数的对比
软阈值化在深度残差收缩网络中是作为非线性映射,而现在深度学习最常用的非线性映射是ReLU激活函数。所以下面进行了两者的对比。
3.1 共同优点
我们首先分析一下,收缩(这里指软阈值化)和ReLU激活函数的共同优点。
首先,软阈值化和ReLU都可以将部分区间的特征置为0,相当于删除部分特征/信息。(可理解为,前面的层将冗余特征转换到某个取值区间,然后用软阈值化或ReLU进行删除)
其次,软阈值化和ReLU的梯度都要么为0,要么为1,都有利于梯度的反向传播。
3.2 收缩(这里指软阈值化)与ReLU的初步对比
相较于ReLU,软阈值化能够更加灵活地设置“待删除(置为0)”的特征取值区间。
我们首先独立地看ReLU,以下图为例。ReLU将低于0的特征,全部删除(置为0);大于0的特征,全部保留(保持不变)。
软阈值函数呢?它将某个区间,也就是[-阈值,阈值]这一区间内的特征删除(置为零);将这个区间之外的部分,包括大于阈值和小于-阈值的部分,保留下来(虽然朝向0进行了收缩)。下图展示了阈值t=10的情况:
在深度残差收缩网络中,阈值是可以通过注意力机制自动设置的。也就是说,[-阈值,阈值]的区间,是可以根据样本自身情况、自动调整的。
3.3 收缩(这里指软阈值化)与ReLU的深层对比
如果我们把ReLU和之前(卷积层或者批标准化里面的)偏置b,放在一起看呢?那么ReLU能够删除的特征取值空间,是可以变化的。比如说,将偏置b和ReLU作为一个整体的话,函数形式就变成了max(x+b,0)或者ReLU(x+b)。当偏置b为正数的时候,特征x会沿y轴向上平移,然后再将负特征置为0。例如,当b=20的时候,如下图所示:
或者当偏置b为负数的时候,特征x会沿y轴向下平移,然后再将负特征置为0。例如,当b=-20的时候,如下图所示:
接下来,我们来讨论软阈值函数。将偏置b和软阈值化作为一个整体的话,函数形式就变成了sign(x+b)•max(abs(x+b)-t,0)。当偏置b为正数的时候,首先特征x会沿y轴向上平移,然后再将零附近的特征置为0。例如,当偏置b=20、阈值t=10的时候,如下图所示:
当偏置b为负时,特征x会沿y轴向下平移,然后再将零附近的特征置为0。例如,当偏置b=-20、阈值t=10的时候,如下图所示:
在深度残差收缩网络中,由于偏置b和阈值t都是可以训练得到的参数,所以当偏置b和阈值t取值合适的时候,软阈值化是可以实现与ReLU相同的功能的。也就是,在现有的这些特征的[最小值,最大值]的范围内(不考虑无穷的情况,一般我们采集的数据不会有无穷),将低于某个值的特征全置为0,或者将高于某个值的特征全置为0。例如,在下图的数据中,如果我们将偏置b设置为20,将阈值t也设置为20,就将所有小于0的特征全部置为0了。因为没有小于-40的特征,所以“偏置+软阈值化”就相当于实现了ReLU的功能(将低于0的特征置为0)。
当然,由于[-阈值,阈值]区间和偏置b都是可调的,也可以是这样(b=40,t=20)(是不是和“偏置+ReLU”很相似):
然而,反过来的话,不管“偏置+ReLU”怎么组合,都无法实现下图中软阈值函数可以实现的功能。也就是,“偏置+ReLU”无法将某个区间内特征的置为0,并且同时保留大于上界和小于下界的特征。
从这个角度看的话,当和前一层的偏置放在一起看的时候,软阈值化比ReLU能够更加灵活地设置“待删除特征的取值区间”。
4.注意力机制的加持
更重要地,深度残差收缩网络采用了注意力机制(类似于Squeeze-and-Excitation Network)自动设置阈值,避免了人工设置阈值的麻烦。(人工设置阈值一直是一个大麻烦,而深度残差收缩网络用注意力机制解决了这个大麻烦)。
在注意力机制中,深度残差收缩网络采用了特殊的网络结构,保障了阈值不仅为正数,而且不会太大。因为如果阈值过大的话,就可能出现下图的情况,也就是所有特征都被置为0了。深度残差收缩网络的阈值,其实是(特征图的绝对值的平均值)×(0到1之间的系数),很好地避免了阈值太大的情况。
同时,深度残差收缩网络的阈值,是在注意力机制下,根据每个样本的情况,单独设置的。也就是,每个样本,都有自己的一组独特的阈值。因此,深度残差收缩网络适用于各个样本中噪声含量不同的情况。
5.深度残差收缩网络只适用于强噪声的数据吗?
我们在使用深度残差收缩网络的时候,似乎不需要考虑数据中是否真的含有很多噪声。换言之,深度残差收缩网络应该可以用于弱噪声的数据。
这是因为,深度残差收缩网络中的阈值,是根据样本自身的情况,通过一个小型子网络自动获得的。如果样本所含噪声很少,那么阈值可以被自动设置得很低(接近于0),从而“软阈值化”就退化成了“直接相等”。在这种情况下,软阈值化,就相当于不存在了。
6.恒等连接降低了训练难度
相较于普通的残差网络,深度残差收缩网络的结构较为复杂,所以恒等路径是有必要存在的。
7. 论文网址
M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096
https://github.com/zhao62/Deep-Residual-Shrinkage-Networks
8. Keras示例代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""
from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)
# Input image dimensions
img_rows, img_cols = 28, 28
# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
def abs_backend(inputs):
return K.abs(inputs)
def expand_dim_backend(inputs):
return K.expand_dims(K.expand_dims(inputs,1),1)
def sign_backend(inputs):
return K.sign(inputs)
def pad_backend(inputs, in_channels, out_channels):
pad_dim = (out_channels - in_channels)//2
inputs = K.expand_dims(inputs,-1)
inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
return K.squeeze(inputs, -1)
# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2):
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
# Calculate global means
residual_abs = Lambda(abs_backend)(residual)
abs_mean = GlobalAveragePooling2D()(residual_abs)
# Calculate scaling coefficients
scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(abs_mean)
scales = BatchNormalization()(scales)
scales = Activation('relu')(scales)
scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
scales = Lambda(expand_dim_backend)(scales)
# Calculate thresholds
thres = keras.layers.multiply([abs_mean, scales])
# Soft thresholding
sub = keras.layers.subtract([residual_abs, thres])
zeros = keras.layers.subtract([sub, sub])
n_sub = keras.layers.maximum([sub, zeros])
residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
# Downsampling using the pooL-size of (1, 1)
if downsample_strides > 1:
identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
# Zero_padding to match channels
if in_channels != out_channels:
identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
residual = keras.layers.add([residual, identity])
return residual
# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))
# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])
9. TFLearn程序
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""
from __future__ import division, print_function, absolute_import
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2, activation='relu', batch_norm=True,
bias=True, weights_init='variance_scaling',
bias_init='zeros', regularizer='L2', weight_decay=0.0001,
trainable=True, restore=True, reuse=False, scope=None,
name="ResidualBlock"):
# residual shrinkage blocks with channel-wise thresholds
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
# Variable Scope fix for older TF
try:
vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
reuse=reuse)
except Exception:
vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
with vscope as scope:
name = scope.name #TODO
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3,
downsample_strides, 'same', 'linear',
bias, weights_init, bias_init,
regularizer, weight_decay, trainable,
restore)
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3, 1, 'same',
'linear', bias, weights_init,
bias_init, regularizer, weight_decay,
trainable, restore)
# get thresholds and apply thresholding
abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tflearn.batch_normalization(scales)
scales = tflearn.activation(scales, 'relu')
scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
# soft thresholding
residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
# Downsampling
if downsample_strides > 1:
identity = tflearn.avg_pool_2d(identity, 1,
downsample_strides)
# Projection to new dimension
if in_channels != out_channels:
if (out_channels - in_channels) % 2 == 0:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch]])
else:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch+1]])
in_channels = out_channels
residual = residual + identity
return residual
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
data_preprocessing=img_prep,
data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]