残差网络为何有效,都有哪些发展? - amaze2的回答 - 知乎 https://www.zhihu.com/questio...
高票回答已经很清楚地介绍了残差网络,下面补充一个比较新颖的残差网络的改进,也就是残差收缩网络Residual Shrinkage Network。
1.动机(冗余信息无处不在)
在开展机器学习任务的时候,我们的数据集中往往或多或少地包含着一定的冗余信息。这些冗余信息会对深度神经网络的特征学习效果造成不利的影响。
因此,从这个角度讲的话,我们在设计深度神经网络的时候,或许应该刻意增强深度神经网络剔除冗余信息的能力。
2.残差收缩网络的基本模块
如下图所示,残差收缩网络在其基本模块中加入了一个子网络,来学习得到一组阈值,然后对残差路径进行软阈值化(即“收缩”)。这个过程可以看成是一种非常灵活的、删除冗余信息的方式。
3.软阈值化的优势
(1)灵活地删除冗余信息。软阈值化能够将位于[-τ, τ]区间的特征置为零,将其他特征也朝着零的方向进行收缩。如果和前一个卷积层的偏置b放在一起看的话,置为零的区间其实就是[-τ+b, τ+b]。在这里,τ和b都是可训练的参数。在这里,软阈值化其实可以将任意区间的特征置为零,也就是删除掉,是一种非常灵活的、删除冗余信息的方式。
(2)梯度要么为零,要么为一。这个特点是和ReLU激活函数一样的,有利于减小梯度消失和梯度爆照的风险。
4.子网络与软阈值化的结合
软阈值化中的阈值设置是一个难题。在残差收缩网络中,阈值是通过一个子网络自动获得的,不需要人工设置,避免了这个难题,而且还带来了以下优点:
(1)每个样本可以有自己独特的阈值。一个数据集中,可能有的样本冗余信息较多,有些样本冗余信息较少,那么它们的阈值应该是有所不同的。残差收缩网络借助这个子网络,能够给各个样本赋予不同的阈值。
(2)阈值为正数,且不会太大。在软阈值化中,阈值必须是正的,而且不能太大,否则输出会全部为0。残差收缩网络的基本模块经过专门的设计,能满足这一条件。
5.整体结构
6.简单程序
将残差收缩网络用于加噪MNIST图像的分类,代码如下(仅供参考):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""
from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)
# Input image dimensions
img_rows, img_cols = 28, 28
# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
def abs_backend(inputs):
return K.abs(inputs)
def expand_dim_backend(inputs):
return K.expand_dims(K.expand_dims(inputs,1),1)
def sign_backend(inputs):
return K.sign(inputs)
def pad_backend(inputs, in_channels, out_channels):
pad_dim = (out_channels - in_channels)//2
inputs = K.expand_dims(inputs,-1)
inputs = K.spatial_3d_padding(inputs, ((0,0),(0,0),(pad_dim,pad_dim)), 'channels_last')
return K.squeeze(inputs, -1)
# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2):
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides),
padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
residual = BatchNormalization()(residual)
residual = Activation('relu')(residual)
residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(residual)
# Calculate global means
residual_abs = Lambda(abs_backend)(residual)
abs_mean = GlobalAveragePooling2D()(residual_abs)
# Calculate scaling coefficients
scales = Dense(out_channels, activation=None, kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(abs_mean)
scales = BatchNormalization()(scales)
scales = Activation('relu')(scales)
scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
scales = Lambda(expand_dim_backend)(scales)
# Calculate thresholds
thres = keras.layers.multiply([abs_mean, scales])
# Soft thresholding
sub = keras.layers.subtract([residual_abs, thres])
zeros = keras.layers.subtract([sub, sub])
n_sub = keras.layers.maximum([sub, zeros])
residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
# Downsampling using the pooL-size of (1, 1)
if downsample_strides > 1:
identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
# Zero_padding to match channels
if in_channels != out_channels:
identity = Lambda(pad_backend, arguments={'in_channels':in_channels,'out_channels':out_channels})(identity)
residual = keras.layers.add([residual, identity])
return residual
# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))
# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])
TFLearn版代码:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis,
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""
from __future__ import division, print_function, absolute_import
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
downsample_strides=2, activation='relu', batch_norm=True,
bias=True, weights_init='variance_scaling',
bias_init='zeros', regularizer='L2', weight_decay=0.0001,
trainable=True, restore=True, reuse=False, scope=None,
name="ResidualBlock"):
# residual shrinkage blocks with channel-wise thresholds
residual = incoming
in_channels = incoming.get_shape().as_list()[-1]
# Variable Scope fix for older TF
try:
vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
reuse=reuse)
except Exception:
vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
with vscope as scope:
name = scope.name #TODO
for i in range(nb_blocks):
identity = residual
if not downsample:
downsample_strides = 1
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3,
downsample_strides, 'same', 'linear',
bias, weights_init, bias_init,
regularizer, weight_decay, trainable,
restore)
if batch_norm:
residual = tflearn.batch_normalization(residual)
residual = tflearn.activation(residual, activation)
residual = conv_2d(residual, out_channels, 3, 1, 'same',
'linear', bias, weights_init,
bias_init, regularizer, weight_decay,
trainable, restore)
# get thresholds and apply thresholding
abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tflearn.batch_normalization(scales)
scales = tflearn.activation(scales, 'relu')
scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
# soft thresholding
residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
# Downsampling
if downsample_strides > 1:
identity = tflearn.avg_pool_2d(identity, 1,
downsample_strides)
# Projection to new dimension
if in_channels != out_channels:
if (out_channels - in_channels) % 2 == 0:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch]])
else:
ch = (out_channels - in_channels)//2
identity = tf.pad(identity,
[[0, 0], [0, 0], [0, 0], [ch, ch+1]])
in_channels = out_channels
residual = residual + identity
return residual
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
data_preprocessing=img_prep,
data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]
Minghang Zhao, Shisheng Zhong, Xuyun Fu, Baoping Tang, Michael Pecht, Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2020, 16(7): 4681-4690.