High temporal resolution and asymmetric spatial activation are fundamental characteristics of emotional processes in the brain. To learn the temporal dynamics and spatial asymmetry of EEG for accurate and generalizable emotion recognition, Yi Ding et al. proposed TSception, a multi-scale convolutional neural network that classifies emotions from EEG. TSception consists of a dynamic temporal layer, an asymmetric spatial layer, and a high-level fusion layer, which learn discriminative representations in the time and channel dimensions simultaneously. The dynamic temporal layer consists of multi-scale 1D convolutional kernels whose lengths are tied to the EEG sampling rate; it learns dynamic temporal and frequency representations of the EEG. The asymmetric spatial layer exploits the asymmetric EEG patterns of emotion and learns discriminative global and hemispheric representations. The authors' code is available at https://github.com/yi-ding-cs/TSception, and the complete code for this reproduction is available at https://github.com/ruix6/tsception
The formula derivations behind the model structure can be found in the original paper and are not expanded here; the figure below shows the concrete structure of the model:
Readers familiar with classic deep learning models will recognize at a glance that TSception's design is inspired by the Inception model. Tailored to the characteristics of EEG signals, TSception processes them in four steps: a dynamic temporal layer (multi-scale 1D temporal convolutions), an asymmetric spatial layer (global and hemispheric spatial convolutions), a fusion layer, and a fully connected classifier.
The code is as follows:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, AveragePooling2D, Flatten, Dense, Dropout, BatchNormalization, concatenate, LeakyReLU
# Convolution block shared by the temporal, spatial, and fusion layers:
# Conv2D -> LeakyReLU -> average pooling
def conv_block(input, out_chan, kernel, step, pool):
    x = Conv2D(out_chan, kernel, strides=step)(input)  # optionally: padding='same', use_bias=False
    x = LeakyReLU()(x)
    x = AveragePooling2D(pool_size=(1, pool), strides=(1, pool))(x)
    return x
def Tsception(num_classes, Chans, Samples, sampling_rate, num_T, num_S, hidden, dropout_rate, pool=8):
    '''
    Expects inputs of shape (Chans, Samples, 1);
    num_T and num_S are the numbers of temporal and spatial kernels.
    '''
    # Temporal kernel lengths as fractions of one second of EEG: 0.5 s, 0.25 s, 0.125 s
    inception_window = [0.5, 0.25, 0.125]
    # Input layer
    input = Input(shape=(Chans, Samples, 1))
    # Dynamic temporal layer: multi-scale 1D convolutions along the time axis
    x1 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[0])), 1, pool)
    x2 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[1])), 1, pool)
    x3 = conv_block(input, num_T, (1, int(sampling_rate * inception_window[2])), 1, pool)
    # Concatenate along the width (time) dimension
    x = concatenate([x1, x2, x3], axis=2)
    x = BatchNormalization()(x)
    # Asymmetric spatial layer: one kernel spanning all channels (global)
    # and one spanning half of them (hemispheric)
    y1 = conv_block(x, num_S, (Chans, 1), (Chans, 1), int(pool*0.25))
    y2 = conv_block(x, num_S, (int(Chans*0.5), 1), (int(Chans*0.5), 1), int(pool*0.25))
    # Concatenate along the height (channel) dimension
    y = concatenate([y1, y2], axis=1)
    y = BatchNormalization()(y)
    # Fusion layer
    z = conv_block(y, num_S, (3, 1), (3, 1), 4)
    z = BatchNormalization()(z)
    # Global average pooling over the remaining time dimension
    z = AveragePooling2D(pool_size=(1, z.shape[2]))(z)
    z = Flatten()(z)
    # Fully connected classifier
    z = Dense(hidden, activation='relu')(z)  # optionally: use_bias=False
    z = Dropout(dropout_rate)(z)
    z = Dense(num_classes, activation='softmax')(z)  # optionally: use_bias=False
    return Model(inputs=input, outputs=z)
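It is worth checking the temporal-branch shapes by hand. With sampling_rate=128, the three inception windows give kernel lengths of 64, 32, and 16 samples; a 'valid' convolution over 512 samples then yields widths of 449, 481, and 497, which the pooling layer (pool=8) reduces to 56, 60, and 62, for a concatenated width of 178. A minimal sketch of this arithmetic (plain Python, no TensorFlow required):

# Sanity-check the temporal-branch output widths (matches the model summary below).
sampling_rate, samples, pool = 128, 512, 8
widths = []
for window in [0.5, 0.25, 0.125]:
    kernel = int(sampling_rate * window)   # kernel length in samples: 64, 32, 16
    conv_out = samples - kernel + 1        # 'valid' convolution width: 449, 481, 497
    widths.append(conv_out // pool)        # pooled width: 56, 60, 62
print(widths, sum(widths))                 # [56, 60, 62] 178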
Using the hyperparameters from the original paper, the model is instantiated as follows:
if __name__ == '__main__':
    model = Tsception(num_classes=2, Chans=28, Samples=512, sampling_rate=128, num_T=15, num_S=15, hidden=32, dropout_rate=0.5)
    model.summary()
The resulting output is:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 28, 512, 1) 0 []
]
conv2d (Conv2D) (None, 28, 449, 15) 975 ['input_1[0][0]']
conv2d_1 (Conv2D) (None, 28, 481, 15) 495 ['input_1[0][0]']
conv2d_2 (Conv2D) (None, 28, 497, 15) 255 ['input_1[0][0]']
leaky_re_lu (LeakyReLU) (None, 28, 449, 15) 0 ['conv2d[0][0]']
leaky_re_lu_1 (LeakyReLU) (None, 28, 481, 15) 0 ['conv2d_1[0][0]']
leaky_re_lu_2 (LeakyReLU) (None, 28, 497, 15) 0 ['conv2d_2[0][0]']
average_pooling2d (AveragePool (None, 28, 56, 15) 0 ['leaky_re_lu[0][0]']
ing2D)
average_pooling2d_1 (AveragePo (None, 28, 60, 15) 0 ['leaky_re_lu_1[0][0]']
oling2D)
average_pooling2d_2 (AveragePo (None, 28, 62, 15) 0 ['leaky_re_lu_2[0][0]']
oling2D)
concatenate (Concatenate) (None, 28, 178, 15) 0 ['average_pooling2d[0][0]',
'average_pooling2d_1[0][0]',
'average_pooling2d_2[0][0]']
batch_normalization (BatchNorm (None, 28, 178, 15) 60 ['concatenate[0][0]']
alization)
conv2d_3 (Conv2D) (None, 1, 178, 15) 6315 ['batch_normalization[0][0]']
conv2d_4 (Conv2D) (None, 2, 178, 15) 3165 ['batch_normalization[0][0]']
leaky_re_lu_3 (LeakyReLU) (None, 1, 178, 15) 0 ['conv2d_3[0][0]']
leaky_re_lu_4 (LeakyReLU) (None, 2, 178, 15) 0 ['conv2d_4[0][0]']
average_pooling2d_3 (AveragePo (None, 1, 89, 15) 0 ['leaky_re_lu_3[0][0]']
oling2D)
average_pooling2d_4 (AveragePo (None, 2, 89, 15) 0 ['leaky_re_lu_4[0][0]']
oling2D)
concatenate_1 (Concatenate) (None, 3, 89, 15) 0 ['average_pooling2d_3[0][0]',
'average_pooling2d_4[0][0]']
batch_normalization_1 (BatchNo (None, 3, 89, 15) 60 ['concatenate_1[0][0]']
rmalization)
conv2d_5 (Conv2D) (None, 1, 89, 15) 690 ['batch_normalization_1[0][0]']
leaky_re_lu_5 (LeakyReLU) (None, 1, 89, 15) 0 ['conv2d_5[0][0]']
average_pooling2d_5 (AveragePo (None, 1, 22, 15) 0 ['leaky_re_lu_5[0][0]']
oling2D)
batch_normalization_2 (BatchNo (None, 1, 22, 15) 60 ['average_pooling2d_5[0][0]']
rmalization)
average_pooling2d_6 (AveragePo (None, 1, 1, 15) 0 ['batch_normalization_2[0][0]']
oling2D)
flatten (Flatten) (None, 15) 0 ['average_pooling2d_6[0][0]']
dense (Dense) (None, 32) 512 ['flatten[0][0]']
dropout (Dropout) (None, 32) 0 ['dense[0][0]']
dense_1 (Dense) (None, 2) 66 ['dropout[0][0]']
==================================================================================================
Total params: 12,653
Trainable params: 12,563
Non-trainable params: 90
__________________________________________________________________________________________________
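To confirm the model trains end to end, here is a minimal sketch on random dummy data with the same shapes as above; the compile settings (Adam, sparse categorical cross-entropy) are my own assumptions for illustration, not the training configuration from the paper:

import numpy as np

# Random dummy data: 64 trials of 28 channels x 512 samples, binary labels.
X = np.random.randn(64, 28, 512, 1).astype('float32')
y = np.random.randint(0, 2, size=(64,))

model = Tsception(num_classes=2, Chans=28, Samples=512, sampling_rate=128,
                  num_T=15, num_S=15, hidden=32, dropout_rate=0.5)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X, y, batch_size=16, epochs=2, verbose=1)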
The authors of the original paper did not tune the model's hyperparameters in much more detail. In fact, fixing the hidden fully connected layer at 32 neurons is probably of little benefit; connecting the flattened features directly to the output layer may work just as well. Moreover, if further reducing the parameter count matters, the bias weights of the convolutional layers can be removed. In the original paper, the model's performance showed no statistically significant difference from other convolutional neural networks, but with further hyperparameter tuning its results could likely be improved considerably.
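As an illustration of those two simplifications (my own modification, not from the paper), the biases can be dropped and the hidden layer removed like this:

# Variant of conv_block with the convolution bias removed.
def conv_block_nobias(input, out_chan, kernel, step, pool):
    x = Conv2D(out_chan, kernel, strides=step, use_bias=False)(input)
    x = LeakyReLU()(x)
    x = AveragePooling2D(pool_size=(1, pool), strides=(1, pool))(x)
    return x

# Inside Tsception, the classifier head could then connect the flattened
# features directly to the output layer:
#     z = Flatten()(z)
#     z = Dropout(dropout_rate)(z)
#     z = Dense(num_classes, activation='softmax', use_bias=False)(z)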